knn(k近邻算法)——python

目录

1. 基本定义

2. 算法原理

2.1 算法优缺点

2.2 算法参数

2.3 变种

3.算法中的距离公式

4.案例实现

4.1 导入相关库

4.2 读取数据

4.3 读取变量名

4.4 定义X,Y数据 

4.5 分离训练集和测试集

4.6 计算欧式距离

4.7 可视化距离矩阵

4.8 预测样本

4.9 查看正确率

4.10 交叉验证

5. scikit-learn的算法实现

5.1 对上述的再次实现:

5.2 另一种实现方式


1. 基本定义

        k最近邻(k-Nearest Ne ighbor)算法是比较简单的机器学习算法。它采用测量不同特征值之间的距离方法进行分类。它的思想很简单:如果一个样本在特征空间中的多个最近邻(最相似〉的样本中的大多数都属于某一个类别,则该样本也属于这个类别。第一个字母k可以小写,表示外部定义的近邻数量。

        简而言之,就是让机器自己按照每一个点的距离,距离近的为一类。

2. 算法原理

        knn算法的核心思想是未标记样本的类别,由距离其最近的k个邻居投票来决定。
        具体的,假设我们有一个已标记好的数据集。此时有一个未标记的数据样本,我们的任务是预测出这个数据样本所属的类别。knn的原理是,计算待标记样本和数据集中每个样本的距离,取距离最近的k个样本。待标记的样本所属类别就由这k个距离最近的样本投票产生。
假设X_test为待标记的样本,X_train为已标记的数据集,算法原理的伪代码如下:

  1. 遍历X_train中的所有样本,计算每个样本与X_test的距离,并把距离保存在Distance数组中。
  2. 对Distance数组进行排序,取距离最近的k个点,记为X_knn。
  3. 在X_knn中统计每个类别的个数,即class0在X_knn中有几个样本,class1在X_knn中有几个样本等。
  4. 待标记样本的类别,就是在X_knn中样本个数最多的那个类别。

2.1 算法优缺点

  • 优点:准确性高,对异常值和噪声有较高的容忍度。
  • 缺点:计算量较大,对内存的需求也较大。

2.2 算法参数

        其算法参数是k,参数选择需要根据数据来决定。

  • k值越大,模型的偏差越大,对噪声数据越不敏感,当k值很大时,可能造成欠拟合;
  • k值越小,模型的方差就会越大,当k值太小,就会造成过拟合。

2.3 变种

        knn算法有一些变种,其中之一是可以增加邻居的权重。默认情况下,在计算距离时,都是使用相同权重。实际上,可以针对不同的邻居指定不同的距离权重,如距离越近权重越高。这个可以通过指定算法的weights参数来实现。
        另一个变种是,使用一定半径内的点取代距离最近的k个点。当数据采样不均匀时,可以有更好的性能。在scikit-learn里,RadiusNeighborsClassifier类实现了这个算法变种。

3.算法中的距离公式

        与我们的线性回归不同,在这里我们并没有什么公式可以进行推导。KNN分类算法的核心就在于计算距离,随后按照距离分类。

    在二维笛卡尔坐标系,相信初中同学应该对这个应该不陌生,他有一个更加常见的名字,直角坐标系。其中,计算两个点之间的距离公式,常用的有欧氏距离。点A(2,3),点B(5,6),那么AB的距离为                        

    这,便是欧氏距离。但和我们平常经常遇到的还是有一些区别的,欧氏距离是可以计算多维数据的,也就是矩阵(Matrix)。这可以帮我们解决很多问题,那么公式也就变成了

4.案例实现

我们使用knn算法及其变种,对Pina印第安人的糖尿病进行预测。数据集可从下面下载。
链接:蓝奏云

4.1 导入相关库

# 导入相关模块
import numpy as np
from collections import Counter
import matplotlib.pyplot as plt
from sklearn.utils import shuffle
import pandas as pd

4.2 读取数据

#读取数据
data=pd.read_excel('D:\桌面\knn.xlsx')
print(data)

返回:

4.3 读取变量名

label_need=data.keys()
print(label_need)

返回:

4.4 定义X,Y数据 

X = data[label_need].values[:,0:8]
y = data[label_need].values[:,8]
print(X)
print(y)

返回:

4.5 分离训练集和测试集

from sklearn.model_selection import train_test_split
X_train, X_test, y_train,y_test = train_test_split(X, y, test_size=0.2)# 打印训练集和测试集大小
print('X_train=', X_train.shape)
print('X_test=', X_test.shape)
print('y_train=', y_train.shape)
print('y_test=', y_test.shape)

返回:

 

4.6 计算欧式距离

# 测试实例样本量
num_test = X.shape[0]
# 训练实例样本量
num_train = X_train.shape[0]
# 基于训练和测试维度的欧氏距离初始化
dists = np.zeros((num_test, num_train)) 
# 测试样本与训练样本的矩阵点乘
M = np.dot(X, X_train.T)
# 测试样本矩阵平方
te = np.square(X).sum(axis=1)
# 训练样本矩阵平方
tr = np.square(X_train).sum(axis=1)
# 计算欧式距离
dists = np.sqrt(-2 * M + tr + np.matrix(te).T) 
print(dists)

返回:

4.7 可视化距离矩阵

dists = compute_distances(X_test, X_train)
plt.imshow(dists, interpolation='none')
plt.show()

返回:

4.8 预测样本

# 测试样本量
num_test = dists.shape[0]
# 初始化测试集预测结果
y_pred = np.zeros(num_test) 
# 遍历   
for i in range(num_test):# 初始化最近邻列表closest_y = []# 按欧氏距离矩阵排序后取索引,并用训练集标签按排序后的索引取值
# 最后拉平列表
# 注意np.argsort函数的用法labels = y_train[np.argsort(dists[i, :])].flatten()# 取最近的k个值closest_y = labels[0:k]# 对最近的k个值进行计数统计# 这里注意collections模块中的计数器Counter的用法c = Counter(closest_y)# 取计数最多的那一个类别y_pred[i] = c.most_common(1)[0][0] 
print(y_pred)

返回:

4.9 查看正确率

查看实际和预测相符的个数:

# 找出预测正确的实例
num_correct = np.sum(y_test_pred == y_test)
print(num_correct)

返回:

计算正确率:

# 计算准确率
accuracy = float(num_correct) / X_test.shape[0]
print('Got %d/%d correct=>accuracy:%f'% (num_correct, X_test.shape[0], accuracy))

返回:

4.10 交叉验证

# 折交叉验证
num_folds = 5
# 候选k值
k_choices = [1, 3, 5, 8, 10, 12, 15, 20, 50, 100]
X_train_folds = []
y_train_folds = []
# 训练数据划分
X_train_folds = np.array_split(X_train, num_folds)
# 训练标签划分
y_train_folds = np.array_split(y_train, num_folds)
k_to_accuracies = {}
# 遍历所有候选k值
for k in k_choices:# 五折遍历    for fold in range(num_folds): # 对传入的训练集单独划出一个验证集作为测试集validation_X_test = X_train_folds[fold]validation_y_test = y_train_folds[fold]temp_X_train = np.concatenate(X_train_folds[:fold] + X_train_folds[fold + 1:])temp_y_train = np.concatenate(y_train_folds[:fold] + y_train_folds[fold + 1:])       # 计算距离temp_dists = compute_distances(validation_X_test, temp_X_train)temp_y_test_pred = predict_labels(temp_y_train, temp_dists, k=k)temp_y_test_pred = temp_y_test_pred.reshape((-1, 1))       # 查看分类准确率num_correct = np.sum(temp_y_test_pred == validation_y_test)num_test = validation_X_test.shape[0]accuracy = float(num_correct) / num_testk_to_accuracies[k] = k_to_accuracies.get(k,[]) + [accuracy]

打印不同 k 值不同折数下的分类准确率:

# 打印不同 k 值不同折数下的分类准确率
for k in sorted(k_to_accuracies):    for accuracy in k_to_accuracies[k]:print('k = %d, accuracy = %f' % (k, accuracy))

返回:

不同 k 值不同折数下的分类准确率的可视化:

for k in k_choices:# 取出第k个k值的分类准确率accuracies = k_to_accuracies[k]# 绘制不同k值准确率的散点图plt.scatter([k] * len(accuracies), accuracies)
# 计算准确率均值并排序
accuracies_mean = np.array([np.mean(v) for k,v in sorted(k_to_accuracies.items())])
# 计算准确率标准差并排序
accuracies_std = np.array([np.std(v) for k,v in sorted(k_to_accuracies.items())])
# 绘制有置信区间的误差棒图
plt.errorbar(k_choices, accuracies_mean, yerr=accuracies_std)
# 绘图标题
plt.title('Cross-validation on k')
# x轴标签
plt.xlabel('k')
# y轴标签
plt.ylabel('Cross-validation accuracy')
plt.show()

返回:

5. scikit-learn的算法实现

5.1 对上述的再次实现:

# 导入KneighborsClassifier模块
from sklearn.neighbors import KNeighborsClassifier
# 创建k近邻实例
neigh = KNeighborsClassifier(n_neighbors=10)
# k近邻模型拟合
neigh.fit(X_train, y_train)
# k近邻模型预测
y_pred = neigh.predict(X_test)
# # 预测结果数组重塑
# y_pred = y_pred.reshape((-1, 1))
# 统计预测正确的个数
num_correct = np.sum(y_pred == y_test)
print(num_correct)
# 计算准确率
accuracy = float(num_correct) / X_test.shape[0]
print('Got %d / %d correct => accuracy: %f' % (num_correct, X_test.shape[0], accuracy))

返回:

5.2 另一种实现方式

5.2.1 加载数据

import pandas as pd
data = pd.read_csv('D:\桌面\knn.csv')
print('dataset shape {}'.format(data.shape))
data.info()

返回:

5.2.2 分离训练集和测试集

X = data.iloc[:, 0:8]
Y = data.iloc[:, 8]
print('shape of X {}, shape of Y {}'.format(X.shape, Y.shape))from sklearn.model_selection import train_test_split
X_train, X_test, Y_train,Y_test = train_test_split(X, Y, test_size=0.2)

返回:

5.2.3 模型比较

使用普通的knn算法、带权重的knn以及指定半径的knn算法分别对数据集进行拟合并计算评分

from sklearn.neighbors import KNeighborsClassifier, RadiusNeighborsClassifier# 构建3个模型
models = []
models.append(('KNN', KNeighborsClassifier(n_neighbors=2)))
models.append(('KNN with weights', KNeighborsClassifier(n_neighbors=2, weights='distance')))
models.append(('Radius Neighbors', RadiusNeighborsClassifier(n_neighbors=2, radius=500.0)))# 分别训练3个模型,并计算得分
results = []
for name, model in models:model.fit(X_train, Y_train)results.append((name, model.score(X_test, Y_test)))
for i in range(len(results)):print('name: {}; score: {}'.format(results[i][0], results[i][1]))

返回:

权重算法,我们选择了距离越近,权重越高。RadiusNeighborsClassifier模型的半径选择了500.从输出可以看出,普通的knn算法还是最好。

问题来了,这个判断准确吗? 答案是:不准确。

因为我们的训练集和测试集是随机分配的,不同的训练样本和测试样本组合可能导致计算出来的算法准确性有差异。

那么该如何解决呢?

我们可以多次随机分配训练集和交叉验证集,然后求模型评分的平均值。

scikit-learn提供了KFold和cross_val_score()函数来处理这种问题。

from sklearn.model_selection import KFold
from sklearn.model_selection import cross_val_scoreresults = []
for name, model in models:kfold = KFold(n_splits=10)cv_result = cross_val_score(model, X, Y, cv=kfold)results.append((name, cv_result))for i in range(len(results)):print('name: {}; cross_val_score: {}'.format(results[i][0], results[i][1].mean()))

返回:

上述代码,我们通过KFold把数据集分成10份,其中1份会作为交叉验证集来计算模型准确性,剩余9份作为训练集。cross_val_score()函数总共计算出10次不同训练集和交叉验证集组合得到的模型评分,最后求平均值。 看起来,还是普通的knn算法性能更优一些。 

5.2.4 模型训练及分析 

据上面模型比较得到的结论,我们接下来使用普通的knn算法模型对数据集进行训练,并查看对训练样本的拟合情况以及对测试样本的预测准确性情况:

knn = KNeighborsClassifier(n_neighbors=2)
knn.fit(X_train, Y_train)
train_score = knn.score(X_train, Y_train)
test_score = knn.score(X_test, Y_test)
print('train score: {}; test score : {}'.format(train_score, test_score))

返回:

从这里可以看到两个问题。

  • 对训练样本的拟合情况不佳,评分才0.84多一些,说明算法模型太简单了,无法很好地拟合训练样本。
  • 模型准确性不好,0.66左右的预测准确性。

我们画出曲线,查看一下。

我们首先定义一下这个画图函数,代码如下:

from sklearn.model_selection import learning_curve
import numpy as npdef plot_learning_curve(plt, estimator, title, X, y, ylim=None, cv=None,n_jobs=1, train_sizes=np.linspace(.1, 1.0, 5)):"""Generate a simple plot of the test and training learning curve.Parameters----------estimator : object type that implements the "fit" and "predict" methodsAn object of that type which is cloned for each validation.title : stringTitle for the chart.X : array-like, shape (n_samples, n_features)Training vector, where n_samples is the number of samples andn_features is the number of features.y : array-like, shape (n_samples) or (n_samples, n_features), optionalTarget relative to X for classification or regression;None for unsupervised learning.ylim : tuple, shape (ymin, ymax), optionalDefines minimum and maximum yvalues plotted.cv : int, cross-validation generator or an iterable, optionalDetermines the cross-validation splitting strategy.Possible inputs for cv are:- None, to use the default 3-fold cross-validation,- integer, to specify the number of folds.- An object to be used as a cross-validation generator.- An iterable yielding train/test splits.For integer/None inputs, if ``y`` is binary or multiclass,:class:`StratifiedKFold` used. If the estimator is not a classifieror if ``y`` is neither binary nor multiclass, :class:`KFold` is used.Refer :ref:`User Guide <cross_validation>` for the variouscross-validators that can be used here.n_jobs : integer, optionalNumber of jobs to run in parallel (default 1)."""plt.title(title)if ylim is not None:plt.ylim(*ylim)plt.xlabel("Training examples")plt.ylabel("Score")train_sizes, train_scores, test_scores = learning_curve(estimator, X, y, cv=cv, n_jobs=n_jobs, train_sizes=train_sizes)train_scores_mean = np.mean(train_scores, axis=1)train_scores_std = np.std(train_scores, axis=1)test_scores_mean = np.mean(test_scores, axis=1)test_scores_std = np.std(test_scores, axis=1)plt.grid()plt.fill_between(train_sizes, train_scores_mean - train_scores_std,train_scores_mean + train_scores_std, alpha=0.1,color="r")plt.fill_between(train_sizes, test_scores_mean - test_scores_std,test_scores_mean + test_scores_std, alpha=0.1, color="g")plt.plot(train_sizes, train_scores_mean, 'o--', color="r",label="Training score")plt.plot(train_sizes, test_scores_mean, 'o-', color="g",label="Cross-validation score")plt.legend(loc="best")return plt

然后我们调用这个函数画一下图看看:

from sklearn.model_selection import ShuffleSplitknn = KNeighborsClassifier(n_neighbors=2)
cv = ShuffleSplit(n_splits=10, test_size=0.2, random_state=0)
plt.figure(figsize=(10,6), dpi=200)
plot_learning_curve(plt, knn, 'Learn Curve for KNN Diabetes', X, Y, ylim=(0.0, 1.01), cv=cv)

返回:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/564980.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C语言中的文件是什么?

我们对文件的概念已经非常熟悉了&#xff0c;比如常见的 Word 文档、txt 文件、源文件等。文件是数据源的一种&#xff0c;最主要的作用是保存数据。 在操作系统中&#xff0c;为了统一对各种硬件的操作&#xff0c;简化接口&#xff0c;不同的硬件设备也都被看成一个文件。对…

knn(k近邻算法)——matlab

目录 1. 基本定义 2. 算法原理 2.1 算法优缺点 2.2 算法参数 2.3 变种 3.算法中的距离公式 4.案例实现 4.1 读取数据 4.2 分离训练集和测试集 4.3 归一化处理 4.4 计算欧氏距离 4.5 排序和输出测试结果 4.6 计算准确率 总代码 1. 基本定义 k最近邻(k-Nearest N…

C语言打开文件详解

C语言中操作文件之前必须先打开文件&#xff1b;所谓“打开文件”&#xff0c;就是让程序和文件建立连接的过程。 打开文件之后&#xff0c;程序可以得到文件的相关信息&#xff0c;例如大小、类型、权限、创建者、更新时间等。在后续读写文件的过程中&#xff0c;程序还可以记…

python turtle虎年来拜年了

1.画个虎 # codingutf-8 from turtle import * import timeCOLOR #B2814Ddef set_start(x, y, w, cCOLOR):penup()setx(x)sety(y)setheading(towards(0, 0))width(w)pencolor(c)pendown()speed(0)def left_rotate(time, angle, length):for i in range(time):left(angle)forwa…

TOPSIS法 —— matlab

目录 1.TOPSIS法介绍 2. 计算步骤 &#xff08;1&#xff09;数据标准化 &#xff08;2&#xff09;得到加权后的矩阵 &#xff08;3&#xff09;确定正理想解和负理想解 &#xff08;4&#xff09;计算各方案到正&#xff08;负&#xff09;理想解的距离 &#xff08;…

TOPSIS法 —— python

目录 1.TOPSIS法介绍 2. 计算步骤 &#xff08;1&#xff09;数据标准化 &#xff08;2&#xff09;得到加权后的矩阵 &#xff08;3&#xff09;确定正理想解和负理想解 &#xff08;4&#xff09;计算各方案到正&#xff08;负&#xff09;理想解的距离 &#xff08;…

C语言随机读写文件

实现随机读写的关键是要按要求移动位置指针&#xff0c;这称为文件的定位。 文件定位函数rewind和fseek 移动文件内部位置指针的函数主要有两个&#xff0c;即 rewind() 和 fseek()。 rewind() 用来将位置指针移动到文件开头&#xff0c;前面已经多次使用过&#xff0c;它的…

mysql-installer安装教程(详细图文)

目录 1.安装 2.配置系统环境变量 3.配置初始化my.ini文件 4.MySQL彻底删除 5.Navicat 安装 1.安装 先去官网下载需要的msi&#xff0c;在这放出官网下载地址下载地址 这里我具体以8.0.28 为安装例子&#xff0c;除了最新版安装界面有些变动以往的都是差不多的。 过去的版本…

Java三种随机数生成方法

java的三种随机数生成方式 随机数的产生在一些代码中很常用&#xff0c;也是我们必须要掌握的。而java中产生随机数的方法主要有三种&#xff1a;     第一种&#xff1a;new Random()     第二种&#xff1a;Math.random()     第三种&#xff1a;currentTimeMil…

Python MySQL入门连接

目录 基本环境准备 navicat的傻瓜使用方式 python连接 mysql安装教程&#xff1a;传送门 基本环境准备 WINR 输入cmd回车打开cmd&#xff0c;登录mysql: mysql -h localhost -u root -p然后输入密码回车即可。 创建用户名为testuser1&#xff1a; CREATE USER testuser1…

JDBC连接sql server数据库

IDEA使用JDBC连接Sqlserver数据库 在IDEA的项目中添加对应数据库的jar包 在项目中创建util包和DBUtil类用来存放数据库连接的java代码。 完整代码 package com.hnpi.util;import java.sql.Connection; import java.sql.DriverManager; import java.sql.PreparedStatement; …

Python MySQL创建表

目录 一、创建表 二、检查表是否存在 三、关键字 一、创建表 在库student环境下创建表名为stu: # codinggbk #连接 import pymysqlmydb pymysql.connect(host"localhost", #默认用主机名port3306,user"root", #默认用户名password"123456"…

JDBC连接 Mysql数据库

IDEA使用JDBC连接Mysql数据库 在项目中添加连接Mysql数据库的jar包 在项目中创建util包和DBUtil类用来存放数据库连接的java代码。 完整代码 package com.zsh.util;import java.sql.Connection; import java.sql.DriverManager; import java.sql.PreparedStatement; import…

Python MySQL插入表

目录 1.插入表格 2.插入多行 3.获取插入的 ID 1.插入表格 要在 MySQL 中填充表&#xff0c;请使用“INSERT INTO”语句。 “stu”表中添加一条记录&#xff1a; 代码&#xff1a; # codinggbk #连接 import pymysqlmydb pymysql.connect(host"localhost", #默认…

Python MySQL选择

目录 1.从表中选择 2.选择列 3.使用 fetchone() 1.从表中选择 要从 MySQL 中的表中进行选择&#xff0c;请使用“SELECT”语句。从“stu”表中选择所有记录&#xff0c;并显示结果&#xff1a; # codinggbk #连接 import pymysqlmydb pymysql.connect(host"localhos…

Eclipse编辑器字体大小的设置

我们在第一次使用 Eclipse 编写程序时&#xff0c;由于 Eclipse 默认使用的是 Cosnolas 字体&#xff0c;字号为 10&#xff0c;所以编辑器中的字体非常小&#xff0c;不方便查看。 我们可以通过下面所示的方法来修改编辑器的字体大小。 操作方法&#xff1a; 1 . 选择“窗口…

Python MySQL查询在哪里(where)

目录 一.用过滤器选择 二.通配符 三.防止 SQL 注入 一.用过滤器选择 从表中选择记录时&#xff0c;可以使用“WHERE”语句过滤选择。例如&#xff1a;选择名字为”笨小孩“的记录&#xff1a;结果&#xff1a; # codinggbk #连接 import pymysqlmydb pymysql.connect(hos…

Python MySQL排序

目录 顺序排序 按 DESC逆序排序 顺序排序 使用 ORDER BY 语句按升序或降序对结果进行排序。ORDER BY 关键字默认对结果进行升序排序。要按降序对结果进行排序&#xff0c;请使用 DESC 关键字。 按名称的字母顺序对结果进行排序&#xff1a; # codinggbk #连接 import pymys…

Python MySQL删除

目录 删除记录 防止 SQL 注入 删除记录 您可以使用“DELETE FROM”语句从现有表中删除记录。例如删除地址为“笨小孩”的任何记录&#xff1a; # codinggbk #连接 import pymysqlmydb pymysql.connect(host"localhost", #默认用主机名port3306,user"root&q…

Python MySQL更新表

目录 更新表 防止 SQL 注入 更新表 您可以使用“UPDATE”语句更新表中的现有记录。将地址栏从“Valley 345”改写为“Canyoun 123”&#xff1a; # codinggbk #连接 import pymysqlmydb pymysql.connect(host"localhost", #默认用主机名port3306,user"root…