人工智能应用-实验5-BP 神经网络分类手写数据集

文章目录

    • 🧡🧡实验内容🧡🧡
    • 🧡🧡代码🧡🧡
    • 🧡🧡分析结果🧡🧡
    • 🧡🧡实验总结🧡🧡

🧡🧡实验内容🧡🧡

编写 BP 神经网络分类, 实现对 MNIST 数据集分类的操作。


🧡🧡代码🧡🧡

需要配置torch。由于是小demo。为了提高效率,我采用的是google的colab进行实验编码,省去配环境的烦恼。

import os
import numpy as np
import torch
import matplotlib.pyplot as plt
from time import time
from torchvision import datasets, transforms
from torch import nn, optim#@title 加载
transform = transforms.Compose([transforms.ToTensor(), # 转为张量,同时如果是图片(uint8)类型,会自动进行归一化到(0,1)transforms.Normalize( (0.5, ) , (0.5, ) ) # 转为std=0.5、mean=0.5的分布, 灰色图像,通道只有一个  将值域(0,1)再次转为(-1,1)])
train_set = datasets.MNIST('train_set', # 下载到该文件夹下download=not os.path.exists('train_set'), # 是否下载,如果下载过,则不重复下载train=True, # 是否为训练集transform=transform # 要对图片做的transform)
test_set = datasets.MNIST('test_set',download=not os.path.exists('test_set'),train=False,transform=transform)
test_set
# train_set[0][0]
train_loader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=64, shuffle=True)dataiter = iter(train_loader)
images, labels = next(iter(dataiter))
print(images.shape)
print(labels.shape)#@title Bp net
class BP_Net(nn.Module):def __init__(self):super().__init__()"""定义第一个线性层,输入为图片(28x28),输出为第一个隐层的输入,大小为128。"""self.linear1 = nn.Linear(28 * 28, 128)self.relu1 = nn.ReLU() # 在第一个隐层使用ReLU激活函数"""定义第二个线性层,输入是第一个隐层的输出,输出为第二个隐层的输入,大小为64。"""self.linear2 = nn.Linear(128, 64)self.relu2 = nn.ReLU() # 在第二个隐层使用ReLU激活函数"""定义第三个线性层,输入是第二个隐层的输出,输出为输出层,大小为10"""self.linear3 = nn.Linear(64, 10)self.softmax = nn.LogSoftmax(dim=1) # 最终的输出经过softmax进行归一化def forward(self, x):"""定义神经网络的前向传播x: 输入的图片数据, shape为(64, 1, 28, 28)"""x = x.view(x.shape[0], -1) # 首先将x的shape转为(64, 784)# 进行前向传播x = self.linear1(x)x = self.relu1(x)x = self.linear2(x)x = self.relu2(x)x = self.linear3(x)x = self.softmax(x)return x
model = BP_Net()
criterion = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=0.003, momentum=0.9)#@title 评估
from sklearn.metrics import confusion_matrix, roc_auc_score, roc_curve
model.eval() # 将模型设置为评估模式correct_count, all_count = 0, 0
predictions = [] # 预测结果列表
true_labels = [] # 真实标签列表for images,labels in test_loader: # 从test_loader中一批一批加载图片for i in range(len(labels)):logps = model(images[i])  # 进行前向传播,获取预测值probab = list(logps.detach().numpy()[0]) # 将预测结果转为概率列表。[0]是取第一张照片的10个数字的概率列表(因为一次只预测一张照片)pred_label = probab.index(max(probab)) # 取最大的index作为预测结果true_label = labels.numpy()[i]if(true_label == pred_label): # 判断是否预测正确correct_count += 1all_count += 1predictions.append(pred_label)true_labels.append(true_label)# 准确率
print("Number Of Images Tested =", all_count)
print("Model Accuracy =", (correct_count/all_count))# 混淆矩阵
def plot_confusion_matrix(cm, classes):plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)plt.title("Confusion Matrix")plt.colorbar()tick_marks = np.arange(len(classes))plt.xticks(tick_marks, classes)plt.yticks(tick_marks, classes)thresh = cm.max() / 2for i in range(cm.shape[0]):for j in range(cm.shape[1]):plt.text(j, i, format(cm[i, j], 'd'), ha="center", va="center",color="white" if cm[i, j] > thresh else "black")plt.ylabel('True Label')plt.xlabel('Predicted Label')plt.tight_layout()plt.show()cm = confusion_matrix(true_labels, predictions)
classes = [str(i) for i in range(10)]
plot_confusion_matrix(cm, classes)#@title 验证
model.train() # 切回训练模式## 验证本地图片
import cv2
from PIL import Image
for num in range(0,10):img = cv2.imread('./myImg/{}.jpg'.format(num), 0)  # 以灰度图的方式读取要预测的图片img = cv2.resize(img, (28, 28))height, width = img.shapedst = np.zeros((height, width), np.uint8)for i in range(height):for j in range(width):dst[i, j] = 255 - img[i, j]dst= dst / 255.0 #归一化dst = (dst - 0.5) / 0.5  # 标准化到[-1, 1]img = dst# print(img)img = np.array(img).astype(np.float32)img = np.expand_dims(img, 0)  # 扩展后,为[1,28,28]img = np.expand_dims(img, 0)  # 扩展后,为[1,1,28,28]img = torch.from_numpy(img)# print(img.shape)with torch.no_grad():output=model(img)# print(output.data)print(output.data.max(1)[1])

🧡🧡分析结果🧡🧡

数据预处理

  • 加载数据集:
    加载torch自带的minst数据集
  • 转换数据:
    先转为tensor变量(相当于直接除255归一化到值域为(0,1))
    在这里插入图片描述
    然后根据std=0.5,mean=0.5,再将值域标准化到(-1,1)
    在这里插入图片描述

设置基本参数:
在这里插入图片描述

构建BP神经网络:
如下,输入为一张2828图片,拆解成2828=784个特征,最终经过三个线性层(784,128)、(128、64)、(64,10),输出为10个特征(对应10个类),归一化这10个特征,它们的大小即认为它属于哪张图片的概率值,取出概率最大的特征对应的类别作为最终预测类别。
在这里插入图片描述

模型训练:
在这里插入图片描述
在这里插入图片描述

模型评估:
准确率:达到97.69%
在这里插入图片描述
混淆矩阵
在这里插入图片描述

接下来,分析网络层数对分类准确率的影响。
被对照试验:隐藏层数目改为2,神经元数目分别为128、64
准确率为:97.69%
对照实验1:隐藏层数目改为3,神经元数目分别为256、128、64
在这里插入图片描述
Loss图:
在这里插入图片描述
准确率和混淆矩阵如下:97.55%
在这里插入图片描述
对照实验2:隐藏层数目改为5,神经元数目分别为512、256、128、64、32
在这里插入图片描述
Loss图:
在这里插入图片描述
准确率和混淆矩阵:97.85%
在这里插入图片描述
总结结果如下表:
在这里插入图片描述
分析可知:

  • 运行时间:从实验结果来看,在增加隐藏层数的情况下,运行时间明显增加。
  • 准确率:实验结果显示,在增加隐藏层数的情况下,准确率大体上有所提升,但是总体变化幅度并不大,可能是因为epochs或者随机梯度下降等参数已经设为较优值,使得准确率已经接近最优效果,从而导致增加网络层数的提优空间并不明显。
    综合来看,增加隐藏层数对于提高分类准确率有一定的帮助,但是也会明显增加运行时间。其次,需要注意的是,若增加隐藏层数并非一定能够带来准确率的提升,过多的隐藏层可能会导致过拟合等问题。

🧡🧡实验总结🧡🧡

在完成基础实验上,我自己画了几张数字图,以对模型进行验证
在这里插入图片描述
结果如下,可以看到,对数字1和数字5分类错误(分布预测成了5和8),其余均分类正确,大体上效果良好。考虑原因,可能是因为minst的数据集是“黑底白字”,而我手画的图片则为“黑字白底”,导致了一些误差。
在这里插入图片描述
理论理解:
通过本次实验,大体上掌握了BP神经网络的定义和结构,总的来说,BP神经网络可以理解为一个黑盒子,通过不断根据loss进行反向传播,最终目的就是得到线性参数w和b,从而根据Y=wx+b 对输入的新x进行预测分类。
代码实践:
一开始想用纯numpy进行BP网络的编写,但是在编写后向传播时,可能是线代和高数知识有些遗忘,求导数时琢磨了很久。后面还是选择直接使用pytorch进行编写,也容易调参,方便进行实验。对我而言,代码中比较纠结的是shape的转换和传入,因此最好多查看中间过程的shape,以便更好理解。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/19164.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

windows 安装 使用 nginx

windows 安装 使用 nginx nginx官网下载地址:https://nginx.org/en/download.html 下载稳定版本即可 下载压缩包解压到即可 进入文件夹中,打开命令行窗口,执行启动命令 start nginx.exe验证(默认是80端口)&#x…

工程项目管理系统的Java实现:高效协同与信息共享

在当今的工程领域,项目管理的高效协同和信息共享是提升工作效率、降低成本的关键。本文将向您介绍一款基于Java技术构建的工程项目管理系统,该系统采用前后端分离的先进技术框架,功能全面,能够满足不同角色的需求,从项…

失落的方舟 命运方舟台服账号怎么注册 游戏账号最全图文注册教程

探索奇幻大陆阿克拉西亚的奥秘,加入《失落的方舟》(Lost Ark)这场史诗般的冒险。这是一款由Smilegate精心雕琢的MMORPG巨作,它融合了激烈动作战斗与深邃故事叙述,引领玩家步入一个因恶魔侵袭而四分五裂的世界。作为勇敢…

How Diffusion Models Work

introduction intuition goal 让神经网络学到图像是什么样的,一种方式是对数据添加不同级别的噪音,让神经网络能够区分细节/总体轮廓 训练一个神经网络去产生精灵 sampling nn

618局外人抖音:别人挤压商家“拼价格”,它默默联合商家“抢用户”?

文|新熔财经 作者|宏一 “618”来临之际,各电商平台和短视频平台早已打响了“促销大战”。不过,今年各大平台都更积极适应新的消费形式,调整了“大促动作”。 比如淘宝、京东带头取消了沿用十年之久的预售机制&…

Stream流模式通信及示例

Stream流模式通信是指在计算机网络中,数据作为连续的字节流传输而不是独立的数据包。它是一种面向连接的通信方式,常见于TCP(传输控制协议)。以下是Stream流模式通信的基本概念和一个简单的示例。 基本概念 面向连接&#xff1…

apollo版本更新简要概述

apollo版本更新简要概述 Apollo 里程碑版本9.0重要更新Apollo 开源平台 9.0 的主要新特征如下:基于包管理的 PnC 扩展开发范式基于包管理的感知扩展开发范式全新打造的 Dreamview Plus 开发者工具感知模型全面升级,支持增量训练 版本8.0版本6.0 Apollo 里…

异步编程的魔力:如何显著提升系统性能

异步编程的魔力:如何显著提升系统性能 今天我们来聊聊一个对开发者非常重要的话题——异步编程。异步编程是提升系统性能的一种强大手段,尤其在需要高吞吐量和低时延的场景中,异步设计能够显著减少线程等待时间,从而提升整体性能。 异步设计如何提升系统性能? 我们通过…

文件IO(二)

文件IO(二) 标准IO缓冲类型全缓冲行缓冲不缓冲 打开文件fopen 操作文件按字符读写(fgetc fputc)按行读写(fgets fputs)按块(对象)读写(fread fwrite)按格式化读写(fscanf…

stm32学习-CubeIDE使用技巧

1.hex文件生成 右键工程 2.仿真调试 3.常用快捷键 作用快捷键代码提示alt/代码注释/反注释ctrl/ 4.项目复制 复制项目,将ioc文件名改为项目名即可图形化编辑

泛型方法、泛型类

如果不需要把类型参数所表示的对象设为实例字段,那么应该优先考虑创建泛型方法,而不是泛型类 在两种情况下,必须把类写成泛型类: 第一种情况,该类需要将某个值用作其内部状态【属性的返回值、字段的返回值等】&#x…

springboot课程题库管理系统-计算机毕业设计源码30812

摘 要 随着科学技术的飞速发展,各行各业都在努力与现代先进技术接轨,通过科技手段提高自身的优势;对于课程题库管理系统 当然也不能排除在外,随着网络技术的不断成熟,带动了课程题库管理系统 ,它彻底改变了…

【刷题(12)】图论

一、图论问题基础 在 LeetCode 中,「岛屿问题」是一个系列系列问题,比如: 岛屿数量 (Easy)岛屿的周长 (Easy)岛屿的最大面积 (Medium)最大人工岛 (Hard&…

【考研数学】数学一和数学二哪个更难?如何复习才能上90分?

很明显考研数学一更难! 不管是复习量还是题目难度 对比项考研数学一考研数学二适用专业理工科类及部分经济学类理工科类考试科目高等数学、线性代数、概率论与数理统计高等数学、线性代数试卷满分150分150分考试时间180分钟180分钟试卷内容结构高等数学约60%&…

电脑怎么清理c盘垃圾文件 电脑运行内存不足怎么清理

和Windows系统电脑文件分区不同,苹果电脑并不分区,默认只有C盘,当C盘垃圾文件过多,电脑运行内存不足时,手动清理电脑垃圾文件毫无头绪,可以尝试使用苹果电脑清理软件——CleanMyMac来清理 。 一、电脑怎么…

React Hooks是如何保存的

React 函数式组件是没有状态的,需要 Hooks 进行状态的存储,那么状态是怎么存储的呢?Hooks是保存在 Fiber 树上的,多个状态是通过链表保存,本文将通过源代码分析 Hooks 的存储位置。 创建组件 首先我们在组件中添加两…

电商推荐系统+电影推荐系统【虚拟机镜像分享】

电商推荐系统电影推荐系统【虚拟机镜像分享】 所有组件部署好的镜像下载(在下面),仅供参考学习。(百度网盘,阿里云盘…) 博主通过学习尚硅谷电商推荐电影推荐项目,将部署好的虚拟机打包成ovf文…

设计模式复习

一、模式所采用的关系(e.g.继承…) UML图例 二、各模式的特点、优缺点 1.创建型 将对象的使用和创建分离,使用对象时无需知道对象的创建细节,使得创建过程可以多次复用,且修改两者中的一个对另一个影响为0或很少。 …

Stable Diffusion WebUI详细使用指南

Stable Diffusion WebUI(AUTOMATIC1111,简称A1111)是一个为高级用户设计的图形用户界面(GUI),它提供了丰富的功能和灵活性,以满足复杂和高级的图像生成需求。由于其强大的功能和社区的活跃参与&…

连锁收银系统支持带结算功能

连锁实体店的收银系统需要支持结算功能,以适应连锁运营效率和提升连锁管理的水平。商淘云连锁收银系统与您一起分享连锁收银系统需支持结算功能的三大必要点。大家点赞收藏,以免划走后找不到了。 一是,连锁模式的运营比较复杂,有加…