K近邻分类器(KNN)(4-2)
K近邻分类器(K-Nearest Neighbor,简称KNN)是一种基本的机器学习分类算法。它的工作原理是:在特征空间中,如果一个样本在特征空间中的K个最相邻的样本中的大多数属于某一个类别,则该样本也属于这个类别。
具体来说,KNN算法首先计算待分类样本与其他所有样本的距离,然后按照距离的递增关系进行排序,选取距离最小的K个样本,最后根据这K个样本的类别通过多数投票等方式进行预测。当K=1时,KNN算法又称为最近邻算法。
KNN算法的优点包括:
- 思想简单,易于理解和实现。
- 对数据分布没有假设,完全基于距离度量进行分类。
- 适用范围广,可以用于多分类问题。
然而,KNN算法也存在一些缺点:
- 对距离度量函数和K值的选择敏感,不同的距离度量函数和K值可能会产生不同的分类结果。
- 计算量大,需要计算待分类样本与所有训练样本的距离。
- 内存需求大,需要存储所有的训练样本。
- 可解释性不强,无法给出决策边界等直观的解释。
KNN算法的应用场景非常广泛,包括但不限于:
- 垃圾邮件识别:可以将邮件分为“垃圾邮件”或“正常邮件”两类。
- 图像内容识别:由于图像的内容种类可能很多,因此这是一个多类分类问题。
- 文本情感分析:既可以作为二分类问题(褒贬两种情感),也可以作为多类分类问题(如十分消极、消极、积极、十分积极等)。
此外,KNN算法还可以用于其他机器学习任务,如手写数字识别、鸢尾花分类等。在这些任务中,KNN算法都表现出了较好的性能。
- 数据实例
ID | Age | Experience | Income | ZIP Code | Family | CCAvg | Education | Mortgage | Personal Loan | Securities Account | CD Account | Online | CreditCard |
1 | 25 | 1 | 49 | 91107 | 4 | 1.6 | 1 | 0 | 0 | 1 | 0 | 0 | 0 |
2 | 45 | 19 | 34 | 90089 | 3 | 1.5 | 1 | 0 | 0 | 1 | 0 | 0 | 0 |
3 | 39 | 15 | 11 | 94720 | 1 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | 0 |
4 | 35 | 9 | 100 | 94112 | 1 | 2.7 | 2 | 0 | 0 | 0 | 0 | 0 | 0 |
5 | 35 | 8 | 45 | 91330 | 4 | 1 | 2 | 0 | 0 | 0 | 0 | 0 | 1 |
6 | 37 | 13 | 29 | 92121 | 4 | 0.4 | 2 | 155 | 0 | 0 | 0 | 1 | 0 |
7 | 53 | 27 | 72 | 91711 | 2 | 1.5 | 2 | 0 | 0 | 0 | 0 | 1 | 0 |
8 | 50 | 24 | 22 | 93943 | 1 | 0.3 | 3 | 0 | 0 | 0 | 0 | 0 | 1 |
9 | 35 | 10 | 81 | 90089 | 3 | 0.6 | 2 | 104 | 0 | 0 | 0 | 1 | 0 |
10 | 34 | 9 | 180 | 93023 | 1 | 8.9 | 3 | 0 | 1 | 0 | 0 | 0 | 0 |
11 | 65 | 39 | 105 | 94710 | 4 | 2.4 | 3 | 0 | 0 | 0 | 0 | 0 | 0 |
12 | 29 | 5 | 45 | 90277 | 3 | 0.1 | 2 | 0 | 0 | 0 | 0 | 1 | 0 |
13 | 48 | 23 | 114 | 93106 | 2 | 3.8 | 3 | 0 | 0 | 1 | 0 | 0 | 0 |
14 | 59 | 32 | 40 | 94920 | 4 | 2.5 | 2 | 0 | 0 | 0 | 0 | 1 | 0 |
15 | 67 | 41 | 112 | 91741 | 1 | 2 | 1 | 0 | 0 | 1 | 0 | 0 | 0 |
16 | 60 | 30 | 22 | 95054 | 1 | 1.5 | 3 | 0 | 0 | 0 | 0 | 1 | 1 |
17 | 38 | 14 | 130 | 95010 | 4 | 4.7 | 3 | 134 | 1 | 0 | 0 | 0 | 0 |
18 | 42 | 18 | 81 | 94305 | 4 | 2.4 | 1 | 0 | 0 | 0 | 0 | 0 | 0 |
19 | 46 | 21 | 193 | 91604 | 2 | 8.1 | 3 | 0 | 1 | 0 | 0 | 0 | 0 |
20 | 55 | 28 | 21 | 94720 | 1 | 0.5 | 2 | 0 | 0 | 1 | 0 | 0 | 1 |
21 | 56 | 31 | 25 | 94015 | 4 | 0.9 | 2 | 111 | 0 | 0 | 0 | 1 | 0 |
22 | 57 | 27 | 63 | 90095 | 3 | 2 | 3 | 0 | 0 | 0 | 0 | 1 | 0 |
23 | 29 | 5 | 62 | 90277 | 1 | 1.2 | 1 | 260 | 0 | 0 | 0 | 1 | 0 |
24 | 44 | 18 | 43 | 91320 | 2 | 0.7 | 1 | 163 | 0 | 1 | 0 | 0 | 0 |
25 | 36 | 11 | 152 | 95521 | 2 | 3.9 | 1 | 159 | 0 | 0 | 0 | 0 | 1 |
26 | 43 | 19 | 29 | 94305 | 3 | 0.5 | 1 | 97 | 0 | 0 | 0 | 1 | 0 |
27 | 40 | 16 | 83 | 95064 | 4 | 0.2 | 3 | 0 | 0 | 0 | 0 | 0 | 0 |
28 | 46 | 20 | 158 | 90064 | 1 | 2.4 | 1 | 0 | 0 | 0 | 0 | 1 | 1 |
29 | 56 | 30 | 48 | 94539 | 1 | 2.2 | 3 | 0 | 0 | 0 | 0 | 1 | 1 |
30 | 38 | 13 | 119 | 94104 | 1 | 3.3 | 2 | 0 | 1 | 0 | 1 | 1 | 1 |
31 | 59 | 35 | 35 | 93106 | 1 | 1.2 | 3 | 122 | 0 | 0 | 0 | 1 | 0 |
32 | 40 | 16 | 29 | 94117 | 1 | 2 | 2 | 0 | 0 | 0 | 0 | 1 | 0 |
33 | 53 | 28 | 41 | 94801 | 2 | 0.6 | 3 | 193 | 0 | 0 | 0 | 0 | 0 |
34 | 30 | 6 | 18 | 91330 | 3 | 0.9 | 3 | 0 | 0 | 0 | 0 | 0 | 0 |
35 | 31 | 5 | 50 | 94035 | 4 | 1.8 | 3 | 0 | 0 | 0 | 0 | 1 | 0 |
36 | 48 | 24 | 81 | 92647 | 3 | 0.7 | 1 | 0 | 0 | 0 | 0 | 0 | 0 |
37 | 59 | 35 | 121 | 94720 | 1 | 2.9 | 1 | 0 | 0 | 0 | 0 | 0 | 1 |
38 | 51 | 25 | 71 | 95814 | 1 | 1.4 | 3 | 198 | 0 | 0 | 0 | 0 | 0 |
39 | 42 | 18 | 141 | 94114 | 3 | 5 | 3 | 0 | 1 | 1 | 1 | 1 | 0 |
40 | 38 | 13 | 80 | 94115 | 4 | 0.7 | 3 | 285 | 0 | 0 | 0 | 1 | 0 |
41 | 57 | 32 | 84 | 92672 | 3 | 1.6 | 3 | 0 | 0 | 1 | 0 | 0 | 0 |
42 | 34 | 9 | 60 | 94122 | 3 | 2.3 | 1 | 0 | 0 | 0 | 0 | 0 | 0 |
43 | 32 | 7 | 132 | 90019 | 4 | 1.1 | 2 | 412 | 1 | 0 | 0 | 1 | 0 |
44 | 39 | 15 | 45 | 95616 | 1 | 0.7 | 1 | 0 | 0 | 0 | 0 | 1 | 0 |
45 | 46 | 20 | 104 | 94065 | 1 | 5.7 | 1 | 0 | 0 | 0 | 0 | 1 | 1 |
46 | 57 | 31 | 52 | 94720 | 4 | 2.5 | 1 | 0 | 0 | 0 | 0 | 0 | 1 |
47 | 39 | 14 | 43 | 95014 | 3 | 0.7 | 2 | 153 | 0 | 0 | 0 | 1 | 0 |
48 | 37 | 12 | 194 | 91380 | 4 | 0.2 | 3 | 211 | 1 | 1 | 1 | 1 | 1 |
49 | 56 | 26 | 81 | 95747 | 2 | 4.5 | 3 | 0 | 0 | 0 | 0 | 0 | 1 |
50 | 40 | 16 | 49 | 92373 | 1 | 1.8 | 1 | 0 | 0 | 0 | 0 | 0 | 1 |
51 | 32 | 8 | 8 | 92093 | 4 | 0.7 | 2 | 0 | 0 | 1 | 0 | 1 | 0 |
52 | 61 | 37 | 131 | 94720 | 1 | 2.9 | 1 | 0 | 0 | 0 | 0 | 1 | 0 |
53 | 30 | 6 | 72 | 94005 | 1 | 0.1 | 1 | 207 | 0 | 0 | 0 | 0 | 0 |
54 | 50 | 26 | 190 | 90245 | 3 | 2.1 | 3 | 240 | 1 | 0 | 0 | 1 | 0 |
55 | 29 | 5 | 44 | 95819 | 1 | 0.2 | 3 | 0 | 0 | 0 | 0 | 1 | 0 |
56 | 41 | 17 | 139 | 94022 | 2 | 8 | 1 | 0 | 0 | 0 | 0 | 1 | 0 |
57 | 55 | 30 | 29 | 94005 | 3 | 0.1 | 2 | 0 | 0 | 1 | 1 | 1 | 0 |
58 | 56 | 31 | 131 | 95616 | 2 | 1.2 | 3 | 0 | 1 | 0 | 0 | 0 | 0 |
59 | 28 | 2 | 93 | 94065 | 2 | 0.2 | 1 | 0 | 0 | 0 | 0 | 0 | 0 |
60 | 31 | 5 | 188 | 91320 | 2 | 4.5 | 1 | 455 | 0 | 0 | 0 | 0 | 0 |
61 | 49 | 24 | 39 | 90404 | 3 | 1.7 | 2 | 0 | 0 | 1 | 0 | 1 | 0 |
62 | 47 | 21 | 125 | 93407 | 1 | 5.7 | 1 | 112 | 0 | 1 | 0 | 0 | 0 |
63 | 42 | 18 | 22 | 90089 | 1 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | 0 |
64 | 42 | 17 | 32 | 94523 | 4 | 0 | 2 | 0 | 0 | 0 | 0 | 1 | 0 |
65 | 47 | 23 | 105 | 90024 | 2 | 3.3 | 1 | 0 | 0 | 0 | 0 | 0 | 0 |
66 | 59 | 35 | 131 | 91360 | 1 | 3.8 | 1 | 0 | 0 | 0 | 0 | 1 | 1 |
67 | 62 | 36 | 105 | 95670 | 2 | 2.8 | 1 | 336 | 0 | 0 | 0 | 0 | 0 |
68 | 53 | 23 | 45 | 95123 | 4 | 2 | 3 | 132 | 0 | 1 | 0 | 0 | 0 |
69 | 47 | 21 | 60 | 93407 | 3 | 2.1 | 1 | 0 | 0 | 0 | 0 | 1 | 1 |
70 | 53 | 29 | 20 | 90045 | 4 | 0.2 | 1 | 0 | 0 | 0 | 0 | 1 | 0 |
71 | 42 | 18 | 115 | 91335 | 1 | 3.5 | 1 | 0 | 0 | 0 | 0 | 0 | 1 |
72 | 53 | 29 | 69 | 93907 | 4 | 1 | 2 | 0 | 0 | 0 | 0 | 1 | 0 |
73 | 44 | 20 | 130 | 92007 | 1 | 5 | 1 | 0 | 0 | 0 | 0 | 0 | 1 |
74 | 41 | 16 | 85 | 94606 | 1 | 4 | 3 | 0 | 0 | 0 | 0 | 1 | 1 |
75 | 28 | 3 | 135 | 94611 | 2 | 3.3 | 1 | 0 | 0 | 0 | 0 | 0 | 1 |
76 | 31 | 7 | 135 | 94901 | 4 | 3.8 | 2 | 0 | 1 | 0 | 1 | 1 | 1 |
使用第1题中的Universal Bank数据集。
注意:数据集中的编号(ID)和邮政编码(ZIP CODE)特征因为在分类模型中无意义,所以在数据预处理阶段将它们删除。
- 使用KNN对数据进行分类
- 使用留出法划分数据集,训练集:测试集为7:3。
# 使用留出法划分数据集,训练集:测试集为7:3
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
- 使用KNN对训练集进行训练
# 使用KNN算法对训练集进行训练,最近邻的数量K设置为5
model = KNeighborsClassifier(n_neighbors=5)
model.fit(X_train, y_train)
最近邻的数量K设置为5。
- 使用训练好的模型对测试集进行预测并输出预测结果和模型准确度
# 使用训练好的模型对测试集进行预测
y_pred = model.predict(X_test)# 输出预测结果
for item in y_pred:print(item, end='\n') # 每项后面都换行,这样就不会合并在一起
print("预测结果:")
print(y_pred)# 输出模型准确度
accuracy = accuracy_score(y_test, y_pred)
print("模型准确度:", accuracy)
完整代码:
# 导入所需的库
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
import pprint# 禁用输出省略
pd.set_option('display.max_rows', None)
pd.set_option('display.max_columns', None)
pd.set_option('display.width', None)
pd.set_option('display.max_colwidth', None)# 读取数据集
data = pd.read_csv("universalbank.csv")# 数据预处理:删除无意义特征
data = data.drop(columns=['ID', 'ZIP Code'])# 划分特征和标签
X = data.drop(columns=['Personal Loan'])
y = data['Personal Loan']# 使用留出法划分数据集,训练集:测试集为7:3
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)# 使用KNN算法对训练集进行训练,最近邻的数量K设置为5
model = KNeighborsClassifier(n_neighbors=5)
model.fit(X_train, y_train)# 使用训练好的模型对测试集进行预测
y_pred = model.predict(X_test)# 输出预测结果
for item in y_pred:print(item, end='\n') # 每项后面都换行,这样就不会合并在一起
print("预测结果:")
print(y_pred)# 输出模型准确度
accuracy = accuracy_score(y_test, y_pred)
print("模型准确度:", accuracy)