文章目录
- 一、kneighborsclassifier是什么?
- 二、使用步骤
- 三、kneighborsclassifier函数及其参数详解
- 1. 参数说明
一、kneighborsclassifier是什么?
kneighborsclassifier
是 scikit-learn
库中 K-近邻算法的实现,用于分类任务。KNN 算法的基本思想是给定一个样本数据集,对于每个输入的新数据点,找到其在样本数据集中最近的 K 个数据点,根据这 K 个邻居的类别来预测新数据点的类别。
二、使用步骤
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris# 载入数据
iris = load_iris()
X = iris.data
y = iris.target# 分割数据集为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
scaler = StandardScaler()
scaler.fit(X_train)X_train = scaler.transform(X_train)
X_test = scaler.transform(X_test)
# 初始化KNN分类器
knn = KNeighborsClassifier(n_neighbors=5)# 训练模型
knn.fit(X_train, y_train)
# 预测测试集
y_pred = knn.predict(X_test)# 分类报告
print(classification_report(y_test, y_pred))# 可视化混淆矩阵
confusion = confusion_matrix(y_test, y_pred)
plt.matshow(confusion)
# 设置中文
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# 在每个单元格中添加数字
for i in range(confusion.shape[0]):for j in range(confusion.shape[1]):plt.text(x=j, y=i, s=str(confusion[i, j]), va='center', ha='center', color='red')plt.colorbar()
plt.ylabel('实际类型')
plt.xlabel('预测类型')
plt.title('混淆矩阵')
plt.show()
三、kneighborsclassifier函数及其参数详解
1. 参数说明
n_neighbors
: 用于指定邻居的数目,默认值是 5。weights
: 用于确定邻居对预测的贡献。可以是 ‘uniform’(默认值,表示所有邻居的权重相同),‘distance’(邻居的权重与距离成反比),或者用户自定义的权重函数。algorithm
: 用于计算最近邻居的算法。可选值有 ‘auto’(默认值,根据数据选择最佳算法),‘ball_tree’,‘kd_tree’,以及 ‘brute’。leaf_size
: 用于指定 BallTree 或 KDTree 中叶节点的大小,默认值是 30。影响树的构建和查询速度。p
: 用于指定距离度量的方法。p=2 是欧氏距离,p=1 是曼哈顿距离。metric
: 用于指定距离度量,默认值是 ‘minkowski’。metric_params
: 用于指定距离度量的附加参数,默认是 None。n_jobs
: 用于指定并行运行的作业数量。-1 表示使用所有的处理器。