利用KNN算法实现手写数字识别
MNIST手写数字识别 是计算机视觉领域中 "hello world"级别的数据集
- 1999年发布,成为分类算法基准测试的基础
- 随着新的机器学习技术的出现,MNIST仍然是研究人员和学习者的可靠资源。
本次案例中,我们的目标是从数万个手写图像的数据集中正确识别数字。
数据介绍
数据文件 train.csv 和 test.csv 包含从 0 到 9 的手绘数字的灰度图像。
-
每个图像高 28 像素,宽28 像素,共784个像素。
-
每个像素取值范围[0,255],取值越大意味着该像素颜色越深
-
训练数据集(train.csv)共785列。第一列为 “标签”,为该图片对应的手写数字。其余784列为该图像的像素值
-
训练集中的特征名称均有pixel前缀,后面的数字([0,783])代表了像素的序号。
像素组成图像如下:
000 001 002 003 ... 026 027
028 029 030 031 ... 054 055
056 057 058 059 ... 082 083| | | | ...... | |
728 729 730 731 ... 754 755
756 757 758 759 ... 782 783
数据集示例如下:
# 导入工具包
import joblib
from sklearn.model_selection import train_test_split, GridSearchCV # 分割训练集和测试集的, 网格搜索 + 交叉验证.
from sklearn.neighbors import KNeighborsClassifier # KNN算法 分类对象
import matplotlib.pyplot as plt # 绘图.
import pandas as pd
from collections import Counter# 需求 定义函数 接收索引 将该行的手写数字 识别为 图片并绘制出来
def dm01_show_digit(idx):# 1. 读取文件 获取df对象data = pd.read_csv('./data/手写数字识别.csv')# 2.判断用户传入值 是否合法if idx < 0 or idx >= len(data):print('传入的索引有误 程序结束! ')return# 走到这里说明 没问题 查看下所有的数据集x = data.iloc[:, 1:]y = data.iloc[:, 0]print(f'数字的种类: {Counter(y)}') # Counter({1: 4684, 7: 4401, 3: 4351, 9: 4188, 2: 4177, 6: 4137, 0: 4132, 4: 4072, 8: 4063, 5: 3795})print(f'像素的形状: {x.shape}')# 根据传入的索引获取到该行的数据print(f'您传入的所有 对应的数字是: {y[idx]}')# 绘制图片# 把图片的像素点 转为 28*28的图片digit = x.iloc[idx].values.reshape(28, 28)# 绘制图片plt.imshow(digit, cmap='gray') # 灰度图plt.axis('off') # 关闭坐标# plt.savefig('./data/demo2.png')plt.show()# 需求2 定义函数 使用KNN算法 用于识别 手写数字 保存模型def dm02_train_mdoel():data = pd.read_csv('./data/手写数字识别.csv')# 数据预处理x = data.iloc[:, 1:]y = data.iloc[:, 0]x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=22, stratify=y)# 特征工程x_train = x_train / 255# 模型训练estimator = KNeighborsClassifier(n_neighbors=9)estimator.fit(x_train, y_train)# 模型评估print(f'准确率: {estimator.score(x_test, y_test)}')# 模型保存joblib.dump(estimator, './model/knn.pkl')def dm03_use_model():# 读取图片 绘制图片img = plt.imread('./data/demo.png')plt.imshow(img,cmap='gray')plt.show()# 读取模型 获取模型对象knn = joblib.load('./model/knn.pkl')# 模型预测y_predict = knn.predict(img.reshape(1,-1))print(f'预测结果为:{y_predict}')if __name__ == '__main__':# dm01_show_digit(20)# dm02_train_mdoel()dm03_use_model()
坚持分享 共同进步