Scikit-learn 识别手写数字

Scikit-learn 识别手写数字的完整教程(包含各模型预测结果和准确率)

本教程将使用 Scikit-learn 提供的手写数字数据集,分别使用支持向量机 (SVM)、随机森林和逻辑回归三种模型进行训练,并展示它们的预测结果和准确率。

1. Scikit-learn 库架构概述

Scikit-learn 是一个流行的机器学习库,提供了大量用于分类、回归、聚类等任务的机器学习工具。我们将使用该库自带的手写数字数据集 (digits) 来构建模型。

2. 官方文档链接

Scikit-learn 官方文档

3. 手写数字数据集

Scikit-learn 提供了一个包含 1797 个 8x8 像素手写数字图像的数据集,标签为数字 0-9。这些图像可用于图像分类任务。

4. 数据集加载和预处理

我们首先加载数据集,并将每个图像展平为 64 维的特征向量(8x8 的像素值展平),然后将数据划分为训练集和测试集。

import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.model_selection import train_test_split# 加载手写数字数据集
digits = datasets.load_digits()# 展示数据集基本信息
print("数据集样本数量:", len(digits.images))
print("每张图片的尺寸:", digits.images[0].shape)# 显示一张手写数字图像
plt.gray()  # 设置为灰度图像
plt.matshow(digits.images[0])  # 显示第一个图像
plt.show()# 将 8x8 的图像展平成 64 维的一维向量
n_samples = len(digits.images)
data = digits.images.reshape((n_samples, -1))# 将数据集划分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(data, digits.target, test_size=0.5, random_state=42)

5. 模型训练与评估

我们将分别使用以下三种模型进行手写数字分类任务:

  • 支持向量机 (SVM)
  • 随机森林 (Random Forest)
  • 逻辑回归 (Logistic Regression)
5.1 支持向量机(SVM)模型
from sklearn import svm
from sklearn.metrics import classification_report, accuracy_score# 实例化 SVM 分类器
svm_classifier = svm.SVC(gamma=0.001)# 使用训练集进行模型训练
svm_classifier.fit(X_train, y_train)# 在测试集上进行预测
y_pred_svm = svm_classifier.predict(X_test)# 输出模型的准确率和分类报告
print("SVM 模型测试集上的准确率:", accuracy_score(y_test, y_pred_svm))
print("SVM 模型分类报告:\n", classification_report(y_test, y_pred_svm))
SVM 模型输出结果:
SVM 模型测试集上的准确率: 0.986652977412731
SVM 模型分类报告:precision    recall  f1-score   support0       1.00      1.00      1.00        881       0.97      1.00      0.98        912       0.98      0.98      0.98        863       1.00      0.99      0.99        914       0.99      0.98      0.98        925       0.97      0.98      0.97        916       0.98      0.98      0.98        917       1.00      0.98      0.99        898       0.97      0.97      0.97        889       0.98      0.95      0.97        89accuracy                           0.99       896macro avg       0.99      0.99      0.99       896
weighted avg       0.99      0.99      0.99       896
5.2 随机森林模型
from sklearn.ensemble import RandomForestClassifier# 实例化随机森林分类器
rf_classifier = RandomForestClassifier(n_estimators=100, random_state=42)# 使用训练集进行模型训练
rf_classifier.fit(X_train, y_train)# 在测试集上进行预测
y_pred_rf = rf_classifier.predict(X_test)# 输出模型的准确率和分类报告
print("随机森林模型测试集上的准确率:", accuracy_score(y_test, y_pred_rf))
print("随机森林模型分类报告:\n", classification_report(y_test, y_pred_rf))
随机森林模型输出结果:
随机森林模型测试集上的准确率: 0.9669642857142857
随机森林模型分类报告:precision    recall  f1-score   support0       1.00      1.00      1.00        881       0.96      0.99      0.97        912       0.99      0.97      0.98        863       1.00      0.98      0.99        914       0.99      0.97      0.98        925       0.98      0.97      0.98        916       0.96      1.00      0.98        917       0.98      0.98      0.98        898       0.94      0.93      0.94        889       0.90      0.89      0.89        89accuracy                           0.97       896macro avg       0.97      0.97      0.97       896
weighted avg       0.97      0.97      0.97       896
5.3 逻辑回归模型
from sklearn.linear_model import LogisticRegression# 实例化逻辑回归模型
lr_classifier = LogisticRegression(max_iter=10000)# 使用训练集进行模型训练
lr_classifier.fit(X_train, y_train)# 在测试集上进行预测
y_pred_lr = lr_classifier.predict(X_test)# 输出模型的准确率和分类报告
print("逻辑回归模型测试集上的准确率:", accuracy_score(y_test, y_pred_lr))
print("逻辑回归模型分类报告:\n", classification_report(y_test, y_pred_lr))
逻辑回归模型输出结果:
逻辑回归模型测试集上的准确率: 0.9464285714285714
逻辑回归模型分类报告:precision    recall  f1-score   support0       1.00      1.00      1.00        881       0.94      0.99      0.96        912       0.98      0.96      0.97        863       1.00      0.97      0.98        914       0.97      0.97      0.97        925       0.96      0.98      0.97        916       0.97      0.99      0.98        917       0.95      0.94      0.95        898       0.88      0.85      0.87        889       0.86      0.82      0.84        89accuracy                           0.95       896macro avg       0.95      0.95      0.95       896
weighted avg       0.95      0.95      0.95       896

6. 预测结果的可视化

为了直观展示模型的预测结果,我们定义一个函数来可视化部分手写数字图像,并显示实际标签和模型的预测标签。

# 定义一个函数来展示部分预测结果
def display_predictions(images, predictions, labels, num_images=5):plt.figure(figsize=(10, 5))for i in range(num_images):plt.subplot(1, num_images, i + 1)plt.imshow(images[i].reshape(8, 8), cmap='gray')plt.title(f'预测: {predictions[i]}\n实际: {labels[i]}')plt.axis('off')plt.show()# 展示各模型的部分预测结果
print("SVM 模型的部分预测结果:")
display_predictions(X_test, y_pred_svm, y_test)print("随机森林模型的部分预测结果:")
display_predictions(X_test, y_pred_rf, y_test)print("逻辑回归模型的部分预测结果:")
display_predictions(X_test, y_pred_lr, y_test)

7. 完整代码汇总

以下是完整的代码片段,包含数据加载、模型训练、预测结果输出和可视化。

import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn import svm
from sklearn.metrics import classification_report, accuracy_score
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression# 加载手写数字数据集
digits = datasets.load_digits()# 数据预处理
n_samples = len(digits.images)
data = digits.images.reshape((n_samples, -1))
X_train, X_test, y_train, y_test = train_test_split(data, digits.target, test_size=0.5, random_state=42)# 支持向量机 (SVM) 模型
svm_classifier = svm.SVC(gamma=0.001)
svm_classifier.fit(X_train, y_train)
y_pred_svm = svm_classifier.predict(X_test)
print("SVM 模型测试集上的准确率:", accuracy_score(y_test, y_pred_svm))
print("SVM 模型分类报告:\n", classification_report(y_test, y_pred_svm))# 随机森林模型
rf_classifier = RandomForestClassifier(n_estimators=100, random_state=42)
rf_classifier.fit(X_train, y_train)
y_pred_rf = rf_classifier.predict(X_test)
print("随机森林模型测试集上的准确率:", accuracy_score(y_test, y_pred_rf))
print("随机森林模型分类报告:\n", classification_report(y_test, y_pred_rf))# 逻辑回归模型
lr_classifier = LogisticRegression(max_iter=10000)
lr_classifier.fit(X_train, y_train)
y_pred_lr = lr_classifier.predict(X_test)
print("逻辑回归模型测试集上的准确率:", accuracy_score(y_test, y_pred_lr))
print("逻辑回归模型分类报告:\n", classification_report(y_test, y_pred_lr))# 展示部分预测结果
def display_predictions(images, predictions, labels, num_images=5):plt.figure(figsize=(10, 5))for i in range(num_images):plt.subplot(1, num_images, i + 1)plt.imshow(images[i].reshape(8, 8), cmap='gray')plt.title(f'预测: {predictions[i]}\n实际: {labels[i]}')plt.axis('off')plt.show()# 展示各模型的预测结果
print("SVM 模型的部分预测结果:")
display_predictions(X_test, y_pred_svm, y_test)print("随机森林模型的部分预测结果:")
display_predictions(X_test, y_pred_rf, y_test)print("逻辑回归模型的部分预测结果:")
display_predictions(X_test, y_pred_lr, y_test)

8. 总结

  • SVM 模型:在手写数字识别任务中的表现最好,达到了 98.67% 的准确率。
  • 随机森林模型:表现也不错,准确率为 96.70%
  • 逻辑回归模型:作为线性模型,尽管表现稍差一些,但也达到了 94.64% 的准确率。

这三种模型的表现都比较优异,具体选择哪种模型取决于任务的复杂性、数据量和计算资源。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/880181.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Frontiers出版社系列SCISSCI合集

【SciencePub学术】本期,小编根据WOS数据库,整理了一下Frontiers出版社系列的SCI&SSCI合集,以供各位学者投稿参考! 来源:WOS数据库 Frontiers系列期刊中,Frontiers in Immunology以其5.7分的影响因子位…

第十四届蓝桥杯嵌入式国赛

一. 前言 本篇博客主要讲述十四届蓝桥杯嵌入式的国赛题目,包括STM32CubeMx的相关配置以及相关功能实现代码以及我在做题过程中所遇到的一些问题和总结收获。如果有兴趣的伙伴还可以去做做其它届的真题,可去 蓝桥云课 上搜索历届真题即可。 二. 题目概述 …

每日一练:二叉树的层序遍历

102. 二叉树的层序遍历 - 力扣(LeetCode) 一、题目要求 给你二叉树的根节点 root ,返回其节点值的 层序遍历 。 (即逐层地,从左到右访问所有节点)。 示例 1: 输入:root [3,9,20,n…

合宙LuatOS应用,与时间相关那些事

合宙嵌入式操作系统LuatOS——在蜂窝物联网模组上推出开源二次开发框架,功能齐全性能稳定,可大幅度降低用户的研发成本和研发周期。 在LuatOS中,获取时间函数用得最多的就是os.time()函数了。接下来,我会讲一些与这个函数以及其他…

c++924

2 #include <iostream> #include <cstring>using namespace std;class MyString { private:char *str; // 记录C风格的字符串int size; // 记录字符串的实际长度public:// 定义无参构造MyString() : size(0) {str new char[1];str[0] \0;cou…

中秋节特别游戏:给玉兔投喂月饼

&#x1f5bc;️ 效果展示 &#x1f4dc; 游戏背景 在中秋这个充满诗意的节日里&#xff0c;玉兔因为贪玩被赶下人间。在这个温柔的夜晚&#xff0c;我们希望通过一个小游戏&#xff0c;让玉兔感受到人间的温暖和关怀。&#x1f430;&#x1f319; &#x1f3ae; 游戏设计 人…

Oracle数据库的比较运算符Comparison Operators

Comparison operators compare one expression to another. The result is always either TRUE, FALSE, or NULL. If the value of one expression is NULL, then the result of the comparison is also NULL. 如果一个表达式的值为NULL&#xff0c;那么比较的结果也是NULL。 …

5、论文阅读:深水下的图像增强

深水下的图像增强 前言介绍贡献UWCNN介绍网络架构残差Residuals块 Blocks网络层密集串联网络深度减少边界伪影网络损失Loss后处理前言 水下场景中,与波长相关的光吸收和散射会降低图像的可见度,导致对比度低和色偏失真。为了解决这个问题,我们提出了一种基于卷积神经网络的…

Rust调用tree-sitter解析C语言

文章目录 一、Rust 调用 tree-sitter 解析 C 语言代码1. 设置 Rust 项目2. 添加 tree-sitter 依赖3. 编写 Rust 代码4. 运行程序5. 编译出错 二、解决步骤1. 添加 tree-sitter 构建依赖2. 添加 tree-sitter-c 源代码3. 修改 build.rs 以编译 tree-sitter-c 库4. 修改 Cargo.tom…

Ubuntu中常用的操作指令

ubuntu中常通过在命令行中输入各种指令完成操作。 文件操作指令 ls&#xff1a;列出目录内容 ls cd&#xff1a;改变当前目录 # 进入指定目录 cd /path/to/directory # 返回上一级目录 cd .. # 返回用户主目录 cd ~ cp&#xff1a;复制文件或目录 # 复制文件 …

伊犁云计算22-1 apache 安装rhel8

1 局域网网络必须通 2 yum 必须搭建成功 3 apache 必须安装 开干 要用su 用户来访问 一看httpd 组件安装完毕 到这里就是测试成功了 如何修改主页的目录 网站目录默认保存在/var/WWW/HTML 我希望改变/home/www 122 127 167 行要改

频率色散效应及其与时间选择性衰落信道的联系

频率色散效应&#xff08;Frequency Dispersion Effect&#xff09;是在无线通信中&#xff0c;由于信道中的多普勒效应引起的现象&#xff0c;它会导致接收信号频谱的扩展和频率上的变化。该效应与信道的时间变化有关&#xff0c;是时间选择性衰落信道&#xff08;time-select…

打造灵活DateTimePicker日期时间选择器组件:轻松实现时间的独立清除功能

element ui中日期和时间选择器&#xff08;DateTimePicker&#xff09;是一个常见且重要的组件。它允许用户轻松地选择日期和时间&#xff0c;极大地提升了用户体验。然而&#xff0c;在某些场景下&#xff0c;用户可能需要更细粒度的控制&#xff0c;例如单独清除已选择的时间…

Swagger 概念和使用以及遇到的问题

前言 接口文档对于前后端开发人员都十分重要。尤其近几年流行前后端分离后接口文档又变 成重中之重。接口文档固然重要,但是由于项目周期等原因后端人员经常出现无法及时更新&#xff0c; 导致前端人员抱怨接口文档和实际情况不一致。 很多人员会抱怨别人写的接口文档不…

mysql性能优化-延迟写和异步写优化

MySQL 性能优化中的延迟写和异步写优化是数据库写入操作中非常重要的技术手段。这些技术可以有效减少磁盘 I/O 操作、提高数据库的吞吐量和整体性能。尤其是在高并发写操作场景下&#xff0c;通过优化写入过程&#xff0c;减少阻塞和等待时间&#xff0c;可以大幅度提升系统的响…

Cassandra 5.0 Spring Boot 3.3 CRUD

概览 因AI要使用到向量存储&#xff0c;JanusGraph也使用到Cassandra 卸载先前版本 docker stop cassandra && docker remove cassandra && rm -rf cassandra/运行Cassandra容器 docker run \--name cassandra \--hostname cassandra \-p 9042:9042 \--pri…

【HarmonyOS】深入理解@Observed装饰器和@ObjectLink装饰器:嵌套类对象属性变化

【HarmonyOS】深入理解Observed装饰器和ObjectLink装饰器&#xff1a;嵌套类对象属性变化 前言 之前就Observed和ObjectLink写过一篇讲解博客【HarmonyOS】 多层嵌套对象通过ObjectLink和Observed实现渲染更新处理&#xff01; 其中就Observe监听类的使用&#xff0c;Object…

ZXing.Net:一个开源条码生成和识别器,支持二维码、条形码等

推荐一个跨平台的非常流行的条码库&#xff0c;方便我们在.Net项目集成条码扫描和生成功能。 01 项目简介 ZXing.Net是ZXing的.Net版本的开源库。支持跨多个平台工作&#xff0c;包括 Windows、Linux 和 macOS&#xff0c;以及在 .NET Core 和 .NET Framework 上运行。 解码…

硬件看门狗导致MCU启动时间慢

最近&#xff0c;在项目交付过程中&#xff0c;我们遇到了一个有趣的问题&#xff0c;与大家分享一下。 客户的需求是&#xff1a;在KL15电压上电后&#xff0c;MCU需要在200ms内发送出第一包CAN报文数据。然而&#xff0c;实际测试结果显示&#xff0c;软件需要360ms才能发送…

【通俗易懂介绍OAuth2.0协议以及4种授权模式】

文章目录 一.OAuth2.0协议介绍二.设计来源于生活三.关于令牌与密码的区别四.应用场景五.接下来分别简单介绍下四种授权模式吧1.客户端模式1.1 介绍1.2 适用场景1.3 时序图 2.密码模式2.1 介绍2.2 适用场景2.3时序图 3.授权码模式3.1 介绍3.2 适用场景3.3 时序图 4.简化模式4.1 …