机器学习实验4——CNN卷积神经网络分类Minst数据集

文章目录

    • 🧡🧡实验内容🧡🧡
    • 🧡🧡 原理🧡🧡
    • 🧡🧡CNN实现分类Minst🧡🧡
      • 代码
      • 数据预处理:
      • 设置基本参数:

🧡🧡实验内容🧡🧡

基于手写minst数据集,完成关于卷积网络CNN的模型训练、测试与评估。

🧡🧡 原理🧡🧡

卷积层
通过使用一组可学习的滤波器(也称为卷积核)对输入图像进行滑动窗口卷积操作,这样可以提取出不同位置的局部特征,从而捕捉到图像的空间结构信息。
激活函数
在卷积层之后,通常会应用一个非线性激活函数,如 ReLU激活函数的作用是引入非线性,使得 CNN 能够学习更复杂的特征表达。
池化层
池化层用于降低特征图的空间尺寸,同时保留最显著的特征信息(类似于人眼观物,是根据物体的主要轮廓来判断物体是什么,而对一些小细节第一眼并没有那么关注)。常见的池化方式包括最大池化和平均池化,它们可以减少计算量,并增加模型的平移不变性。
全连接层
一般在 CNN 的最后几层,全连接层被用来将先前的卷积和池化层的输出与目标类别进行关联,每个神经元在该层中与前一层的所有神经元相连,通过学习权重参数来进行分类决策。
Softmax 函数
在最后一个全连接层之后,通常会应用 Softmax 函数来将神经网络的输出转换为概率分布,用于多类别分类问题的预测。
例如p=[0.2,0.3,0.5],这表示分类为类别1、2、3的概率分别为0.2,0.3,0.5,因此预测分类结果为类别3.
最后通过反向传播算法,CNN 使用训练数据进行模型参数的优化,它通过最小化损失函数(如交叉熵)来调整网络权重,并使用梯度下降等优化算法进行迭代更新。

构建本实验的CNN网络:

  • 5 x 5的卷积核,输入通道为1,输出通道为16:此时图像矩经过卷积核后尺寸变成24 x 24。
  • 2 x 2 的最大池化层:此时图像大小缩短一半,变成 12 x 12,通道数不变;
  • 再次经过 5 x 5 的卷积核,输入通道为16,输出通道为32:此时图像尺寸经过卷积核后变成8 *8。
  • 再次经过 2 x 2 的最大池化层:此时图像大小缩短一半,变成4 x 4,通道数不变;
  • 最后将图像整型变换成向量,输入到全连接层中:输入一共有4 x 4 x 32 = 512 个元素,输出为10.
    -在这里插入图片描述

🧡🧡CNN实现分类Minst🧡🧡

代码

import torch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.optim as optim
from torch import nn, optim
from time import time# ======================准备数据集======================
batch_size = 64
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])
train_dataset = datasets.MNIST(root='../dataset/mnist/',train=True,download=True,transform=transform)
train_loader = DataLoader(train_dataset,shuffle=True,batch_size=batch_size)
test_dataset = datasets.MNIST(root='../dataset/mnist',train=False,download=True,transform=transform)
test_loader = DataLoader(test_dataset,shuffle=False,batch_size=batch_size)# ======================CNN net======================
class CNN_net(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Conv2d(1, 16, kernel_size=5)  # 卷积1self.pooling1 = nn.MaxPool2d(2)  # 最大池化self.relu1 = nn.ReLU()  # 激活self.conv2 = nn.Conv2d(16, 32, kernel_size=5)self.pooling2 = nn.MaxPool2d(2)self.relu2 = nn.ReLU()self.fc = nn.Linear(512, 10)  # 全连接def forward(self, x):batch_size = x.size(0)x = self.conv1(x)x = self.pooling1(x)x = self.relu1(x)x = self.conv2(x)x = self.pooling2(x)x = self.relu2(x)x = x.view(batch_size, -1)x = self.fc(x)return xmodel = CNN_net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)# ====================== train ======================
def train(epoch):time0 = time()  # 记录下当前时间loss_list = []for e in range(epoch):running_loss = 0.0for images, labels in train_loader:outputs = model(images)  # 前向传播获取预测值loss = criterion(outputs, labels)  # 计算损失loss.backward()  # 进行反向传播optimizer.step()  # 更新权重optimizer.zero_grad()  # 清空梯度running_loss += loss.item()  # 累加损失# 一轮循环结束后打印本轮的损失函数print("Epoch {} - Training loss: {}".format(e, running_loss / len(train_loader)))loss_list.append(running_loss / len(train_loader))# 打印总的训练时间print("\nTraining Time (in minutes) =", (time() - time0) / 60)# 绘制损失函数随训练轮数的变化图plt.plot(range(1, epoch + 1), loss_list)plt.xlabel('Epochs')plt.ylabel('Loss')plt.title('Training Loss')plt.show()train(5)# ====================== test ======================
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as pltdef test():model.eval()  # 将模型设置为评估模式correct = 0total = 0all_predicted = []all_labels = []with torch.no_grad():for images, labels in test_loader:outputs = model(images)_, predicted = torch.max(outputs.data, dim=1)total += labels.size(0)correct += (predicted == labels).sum().item()all_predicted.extend(predicted.tolist())all_labels.extend(labels.tolist())print('Model Accuracy =:%.4f' % (correct / total))# 绘制混淆矩阵cm = confusion_matrix(all_labels, all_predicted)plt.figure(figsize=(8, 6))sns.heatmap(cm, annot=True, fmt=".0f", cmap="Blues")plt.xlabel("Predicted Labels")plt.ylabel("True Labels")plt.title("Confusion Matrix")plt.show()test()

数据预处理:

加载数据集:
加载torch库中自带的minst数据集
转换数据:
先转为tensor变量(相当于直接除255归一化到值域为(0,1))
然后根据std=0.5,mean=0.5,再将值域标准化到(-1,1)。
(做完实验后,上网了解发现minst最合适的的std和mean分别为0.1307, 0.3081,但是其实结果都差不多,准确率变化不大,因为数据集还是相对比较简单的)

设置基本参数:

在这里插入图片描述

构建CNN神经网络:
同上述(1)中,已经构建完毕,这里不再赘述。
模型训练:
在这里插入图片描述
可见,虽然只经过5个epoch,但是花的时间为3.3min。
模型分类:
在这里插入图片描述
准确率达98.26%。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/643084.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

接口文档swagger2的使用

Spring-接口文档swagger2 1、swagger/knife4j 接口文档配置 ​ knife4j是swagger的增强版本&#xff0c;更加的小巧、轻量&#xff0c;功能也是更加的完善&#xff0c;UI也更加的清晰&#xff1b;可以从swagger到knife4j无缝切换。 1.1 引入相关依赖 <!--接口文档的开发:…

神经网络:表述(Neural Networks: Representation)

1.非线性假设 无论是线性回归还是逻辑回归&#xff0c;当特征太多时&#xff0c;计算的负荷会非常大。 案例&#xff1a; 假设我们有非常多的特征&#xff0c;例如大于 100 个变量&#xff0c;我们希望用这 100 个特征来构建一个非线性的多项式模型&#xff0c;结果将是数量非…

Win10 如何用powershell写个WOL开机脚本

环境&#xff1a; Win10 专业版 问题描述&#xff1a; Win10 如何用powershell写个WOL开机脚本 解决方案&#xff1a; 1.脚本内容 $mac b1-10-18-52-11-12 $macBytes $mac -split - | ForEach-Object { [byte](0x $_) } $broadcastAddress [byte[]](1..6 | ForEach-O…

springboot导出数据到excel模板,使用hutool导出数据到指定excel,java写入数据到excel模板

最近遇到一个需求&#xff0c;需要从数据库查询数据&#xff0c;写入到对应的excel导入模板中。再把导出的数据进行修改&#xff0c;上传。 我们项目用的是easyExcel&#xff0c;一顿百度搜索&#xff0c;不得其法。 主要是要把数据填充到指定单元格中&#xff0c;跟平时用到的…

【Android】细数Linux和Android系统中的伪文件系统

文章目录 前言Linux伪文件系统cgroupfsLinux的cgroupsAndroid的cgroups debugfsfunctionfs(/dev/usb-ffs/adb)functionfs 的引入sysfs是什么 procfs(/proc)pstore(/sys/fs/pstore)selinuxfs(/sys/fs/selinux)sysfs(/sys)参考 前言 做了好些年Android开发&#xff0c;你了解过L…

Java Web(二)--HTML

基本介绍 官网文档地址: HTML 教程 HTML&#xff08;HyperText Mark-up Language&#xff09;即超文本标签语言&#xff1b;HTML 文本是由 HTML 标签组成的文本&#xff0c;可以包括文字、图形、动画、声音、表格、链接等&#xff1b;HTML 的结构包括头部&#xff08;Head&…

学校“数据结构”课程Project—扩展功能(自主设计)

目录 一、设想功能描述 想法缘起 目标功能 二、问题抽象 三、算法设计和优化 1. 易想的朴素搜索 / dp 搜索想法 动态规划&#xff08;dp&#xff09;想法 2. 思考与优化 四、算法实现 五、结果示例 附&#xff1a;使用的地图API 一、设想功能描述 想法缘起 OSM 导出…

汽车网络架构与常用总线汇总

汽车CAN总线简述 CAN 是控制器局域网Controller Area Network 的缩写&#xff0c;1986年&#xff0c;由德国Bosch公司为汽车开发的网络技术&#xff0c;主要用于汽车的监测与控制&#xff0c;目的为适应汽车“减少线束的数量”“通过多个网络进行大量数据的高速传输”的需求。…

TA百人计划学习笔记 3.1.1模板测试

资料 源视频 【技术美术百人计划】图形 3.1 深度与模板测试 传送门效果示例_哔哩哔哩_bilibili ppt 3100-模板测试与深度测试(1) 参考 Unity Shader: 理解Stencil buffer并将它用于一些实战案例&#xff08;描边&#xff0c;多边形填充&#xff0c;反射区域限定&#xff0c;阴影…

c++学习笔记-STL案例-机房预约系统6-老师模块

前言 衔接上一篇“c学习笔记-STL案例-机房预约系统5-学生模块”&#xff0c;本文主要设计老师模块&#xff0c;从&#xff0c;老师登录和注销、查看所有预约、审核预约三个方面进行分析和实现。 目录 9 教师模块 9.1 教师登录和注销 9.1.1 构造函数 9.1.2 教师子菜单 ​编…

Linux7 安装 Oracle 19C RAC 详细图文教程

实战篇&#xff1a;Linux7 安装 Oracle 19C RAC 详细图文教程 本文是按照&#xff1a;https://www.modb.pro/db/154424的思路进行编写 一、安装前规划 安装RAC前&#xff0c;当然要先做好规划。具体包含以下几方面&#xff1a; 节点主机版本主机名实例名Grid/Oracle版本Publi…

鸿蒙原生开发-仿ChatGPT应用实战

运行环境 DAYU200:4.0.10.16 SDK&#xff1a;4.0.10.15 IDE&#xff1a;4.0.600 前言 在配置好环境之后&#xff0c;可以尝试这编写一个较为简单的应用程序练练手&#xff0c;这里选择使用一个免费的API接口网站 ALAPI来尝试编写一个可进行对话的GPT应用程序。 创建项目 …

SQL注入示例

例一、基础SQL注入&#xff1a;load_file读文件 CISP-PTE 认证考试 首先是有单引号和括号的&#xff0c;首要是要闭合&#xff0c;然后回显点是在-1的位置&#xff0c;读取文件上面的key的话使用的是load_file(/tmp/360/key) id-1)%09ununionion%09select%091,2,3,load_file…

【算法与数据结构】322、LeetCode零钱兑换

文章目录 一、题目二、解法三、完整代码 所有的LeetCode题解索引&#xff0c;可以看这篇文章——【算法和数据结构】LeetCode题解。 一、题目 二、解法 思路分析&#xff1a;本题可以抽象成一个完全背包问题。 第一步&#xff0c; d p [ j ] dp[j] dp[j]的含义。 d p [ j ] dp…

Unity之Cinemachine教程

前言 Cinemachine是Unity引擎的一个高级相机系统&#xff0c;旨在简化和改善游戏中的相机管理。Cinemachine提供了一组强大而灵活的工具&#xff0c;可用于创建令人印象深刻的视觉效果&#xff0c;使开发人员能够更轻松地掌控游戏中的摄像机行为。 主要功能和特性包括&#x…

Springboot+vue的医院后台管理系统(有报告),Javaee项目,springboot vue前后端分离项目

演示视频&#xff1a; Springbootvue的医院后台管理系统&#xff08;有报告&#xff09;&#xff0c;Javaee项目&#xff0c;springboot vue前后端分离项目 项目介绍&#xff1a; 本文设计了一个基于Springbootvue的前后端分离的医院后台管理系统&#xff0c;采用M&#xff08…

LeetCode、875. 爱吃香蕉的珂珂【中等,最小速度二分】

文章目录 前言LeetCode、875. 爱吃香蕉的珂珂【中等&#xff0c;最小速度二分】题目及分类思路分析及代码实现代码优化 资料获取 前言 博主介绍&#xff1a;✌目前全网粉丝2W&#xff0c;csdn博客专家、Java领域优质创作者&#xff0c;博客之星、阿里云平台优质作者、专注于Ja…

如何修改flutter的minSdkVersion版本?

在使用第三方插件的时候&#xff0c;插件对最低的 minSdkVersion版本是有要求的&#xff0c;你比如flutter 插件 webview_flutter 就会报一下错&#xff1a; minSdkVersion 16 cannot be smaller than version 19 declared in library 解决方法①&#xff1a; 这个时候我们需…

Python爬虫框架选择与使用:推荐几个常用的高效爬虫框架

目录 前言 一、Scrapy框架 1. 安装Scrapy 2. Scrapy示例代码 3. 运行Scrapy爬虫 二、Beautiful Soup库 1. 安装Beautiful Soup 2. Beautiful Soup示例代码 3. 运行Beautiful Soup代码 三、Requests库 1. 安装Requests库 2. Requests示例代码 3. 运行Requests代码 …

【蓝桥杯--图论】最小生成树prim、kruskal

今日语录&#xff1a;成功不是终点&#xff0c;失败不是致命&#xff0c;勇气才是取胜的关键。 文章目录 prim算法kruskal算法(稀疏图) prim算法 #include <cstring> #include <algorithm> #include <iostream>#define _CRT_SECURE_NO_WARNINGS using names…