神经网络识别数字图像案例

学习资料:从零设计并训练一个神经网络,你就能真正理解它了_哔哩哔哩_bilibili

这个视频讲得相当清楚。本文是学习笔记,不是原创,图都是从视频上截图的。

1. 神经网络

2. 案例说明

具体来说,设计一个三层的神经网络。以数字图像作为输入,经过神经网络的计算,识别出图像中的数字是几,从而实现数字图像的分类。

3. 视频讲解内容的提纲

4. 神经网络的设计和实现

我们要处理的数据是28*28像素的灰色通道图像。

这样的灰色图像包括了28*28=784个数据点。需要先将他展平为1*784大小的向量。然后将这个向量输入到神经网络中。

用一个三层神经网络处理图片对应的向量X。输入成需要接收784维的图片向量X。X里面每个维度的数据都有一个神经元来接收。因此输入层要包含784个神经元。

隐藏成用于特征提取特征向量,将输入的特征向量处理成更高级的特征向量。

因为手写数字图像识别并不复杂,所以将隐藏层的神经元个数设置为256。这样,输入层和隐藏层之间就会有个784*256的线性层。它可以将一个784维的输入向量转换为256维的输出向量。

该输出向量会继续向前传播到达输出层。

由于最终要将数字图像识别为0~9,十种可能的数字。因此,输出层需要定义10个神经元,对应这十种数字。

256维的向量在经过隐藏层和输出层之间的线性层计算后,就得到了10维的输出结果。这个10维的向量就代表了10个数字的预测得分。

为了继续得到输出层的预测概率,还要将输出层的输出输入到softmax层。softmax层会将10维的向量转换为10个概率值p0~p9。p0~p9相加的总和等于1.

5. 神经网络的Pytorch实现

import torch
from torch import nn# 定义神经网络Network
class Network(nn.Module):def __init__(self):super().__init__()# 线性层1,输入层和隐藏层之间的线性层self.layer1 = nn.Linear(784, 258)# 线性层2,隐藏层和输出层之间的线性层self.layer2 = nn.Linear(256, 10)# 在前向传播,forward函数中,输入为图像xdef forward(self, x):x = x.view(-1, 28 * 28) # 使用view函数,将x展平x = self.layer1(x) # 将x输入到layer1x = torch.relu(x) # 使用relu激活return self.layer2(x) # 输入至layer2计算结果# 这里没有直接定义softmax层,因为后面会使用CrossEntropyLoss损失函数# 在这个损失函数中,会实现softmax的计算

6. 训练数据的准备和处理

from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader# 初学只要知道大致的数据处理流程即可
if __name__ == '__main__'# 实现图像的预处理pipelinetransform = trnasforms.Compose([# 转换成单通道灰度图transforms.Grayscale(num_output_channels=1),# 转换为张量transforms.ToTensor()])# 使用ImageFolder函数,读取数据文件夹,构建数据集dataset# 这个函数会将保持数据的文件夹的名字,作为数据的标签,组织数据train_dataset = datasets.ImageFolder(root='./mnist_images/train', transform=transform)test_dataset = datasets.ImageFolder(root='./mnist_images/test', transform=transform)# 打印他们的长度print("train_dataset length: ", len(train_dataset))print("test_dataset length: ", len(test_dataset))# 使用train_loader, 实现小批量的数据读取# 这里设置小批量的大小,batch_size=64. 也就是每个批次,包括64个数据train_loader = DataLoader(train_datase, batch_size=64, shuffle=True)# 打印train_loader的长度print("train_loader length: ", len(train_loader))# 6000个训练数据,如果每个小批量,读入64个样本,那么60000个数据会被分成938组# 938*64=60032,说明最后一组不够64个数据# 循环遍历train_loader# 每一次循环,都会取出64个图像数据,作为一个小批量batchfor batch_idx, (data, label) in enumerate(train_loader)if batch_idx == 3:breakprint("batch_idx: ", batch_idx)print("data.shape: ", data.shape) # 数据的尺寸print("label: ", label.shape) # 图像中的数字print(label)

7. 模型的训练和测试

import torch
from torch import nn
from torch import optim
from model import Network
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoaderif __name__ == '__main__'# 图像的预处理transform = transforms.Compose([transforms.Grayscale(num_output_channels=1),transforms.ToTensor()])# 读入并构造数据集train_dataset = datasets.ImageFolder(root='./mnist_images/train', transform=transform)print("train_dataset length: ", len(train_dataset))# 小批量的数据读入train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)print("train_loader length: ", len(train_loader))# 在使用Pytorch训练模型时,需要创建三个对象:model = Network() # 1.模型本身,就是我们设计的神经网络optimizer = optim.Adam(model.parameters()) #2.优化器,优化模型中的参数criterion = nn.CrossEntropyLoss() #3.损失函数,分类问题,使用交叉熵损失误差# 进入模型的循环迭代# 外层循环,代表了整个训练数据集的遍历次数for epoch in range(10):# 内层循环使用train_loader, 进行小批量的数据读取for batch_idx, (data, label) in enumerate(train_loader):# 内层每循环一次,就会进行一次梯度下降算法# 包括了5个步骤# 这5个步骤是使用pytorch框架训练模型的定式,初学时先记住即可# 1. 计算神经网络的前向传播结果output = model(data)# 2. 计算output和标签label之间的损失lossloss = criterion(output, label)# 3. 使用backward计算梯度loss.backward()# 4. 使用optimizer.step更新参数optimizer.step()# 5.将梯度清零optimizer.zero_grad()if batch_idx % 100 == 0:print(f"Epoch {epoch + 1}/10"f"| Batch {batch_idx}/{len(train_loader)}"f"| Loss: {loss.item():.4f}")torch.save(model.state_dict(), 'mnist.pth')

from model import Network
from torchvision import transforms
from torchvision import datasets
import torchif __name__ == '__main__'transform = transforms.Compose([transforms.Grayscale(num_output_channels=1),transforms.ToTensor()])# 读取测试数据集test_dataset = datasets.ImageFolder(root='./mnist_images/test', transform=transform)print("test_dataset length: ", len(test_dataset))model = Network() # 定义神经网络模型model.load_state_dict(torch.load('mnist.pth')) # 加载刚刚训练好的模型文件rigth = 0 # 保存正确识别的数量for i, (x, y) in enumerate(test_dataset):output = model(x) # 将其中的数据x输入到模型predict = output.argmax(1).item() # 选择概率最大标签的作为预测结果# 对比预测值predict和真实标签yif predict == y:right += 1else:# 将识别错误的样例打印出来img_path = test_dataset.samples[i][0]print(f"wrong case: predict = {predict} y = {y} img_path = {img_path}")# 计算出测试效果sample_num = len(test_dataset)acc = right * 1.0 / sample_numprint("test accuracy = %d / %d = %.3lf" % (right, sample_num, acc))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/43766.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

如何找工作 校招 | 社招 | 秋招 | 春招 | 提前批

马上又秋招了,作者想起以前读书的时候,秋招踩了很多坑,但是第一份工作其实挺重要的。这里写一篇文章,分享一些校招社招的心得。 现在大学的情况是,管就业的人,大都是没有就业的辅导员(笔者见过…

亿发512版本更新,看数据驾驶舱、扫码拣货、UDI序列号的新功能

如果您正寻求突破传统业务模式的束缚,希望拥抱数字化转型带来的无限可能,我们诚邀您体验亿发软件。亿发专业团队将为您提供个性化的咨询和定制服务,帮助您的企业快速适应市场变化,实现业务模式和商业模式的创新。

【腾讯云生成式AI产品解决方案深度分析 2024】

文末有福利! 腾讯云生成式AI产品解决方案 (一) 基于生成式AI的腾讯云产品架构升级 (二) 腾讯云完善的产品矩阵,满足不同路线客户需求 1. 路线一 标准软件 (1) 腾讯乐享AI助手 落地背景及挑战在企业知识管理、培训学习、办公协同场景中,存…

初识C++ | 基本介绍、命名空间、输入输出、缺省函数、函数重载、引用、内联函数、nullptr

基本介绍 C的起源 1979年,当时的 Bjarne Stroustrup 正在⻉尔实验室从事计算机科学和软件⼯程的研究⼯作。⾯对项⽬中复杂的软件开 发任务,特别是模拟和操作系统的开发⼯作,他感受到了现有语⾔(如C语⾔)在表达能⼒、可…

无法定位程序输入点kernel32.dll ——一键修复丢失kernel32.dll方案

无法定位程序输入点" 错误通常发生在 Windows 操作系统中,当一个程序试图加载一个 DLL(动态链接库)文件中的特定函数,但无法找到该函数的入口点时。kernel32.dll 是 Windows 操作系统中的一个关键 DLL 文件,它包含…

设置DepthBufferBits和设置DepthStencilFormat的区别

1)设置DepthBufferBits和设置DepthStencilFormat的区别 2)Unity打包exe后,游戏内拉不起Steam的内购 3)Unity 2022以上Profiler.FlushMemoryCounters耗时要怎么关掉 4)用GoodSky资产包如何实现昼夜播发不同音乐功能 这是…

【北京迅为】《i.MX8MM嵌入式Linux开发指南》-第一篇 嵌入式Linux入门篇-第十八章 Linux编写第一个自己的命令

i.MX8MM处理器采用了先进的14LPCFinFET工艺,提供更快的速度和更高的电源效率;四核Cortex-A53,单核Cortex-M4,多达五个内核 ,主频高达1.8GHz,2G DDR4内存、8G EMMC存储。千兆工业级以太网、MIPI-DSI、USB HOST、WIFI/BT…

Python-找客户软件

软件功能 请求代码: 填充表格: 可以search全国各个区县的所有企业信息,过滤手机号、查看是否续存/在业状态。方便找客户。 支持定-制-其他引-留-阮*件(XHSS,DYY,KS,Bi-li*Bi-li) V*…

AutoHotKey自动热键(八)脚本快速暂停与重新加载

我们在编辑脚本的时候,可以添加快捷键来改变脚本的状态 ;暂停脚本 F11::Suspend;重置脚本 F12::Reloadreload用来重置脚本 我们可以在脚本开头加上标签提示脚本重启成功 ToolTip, 脚本已经重启 Sleep, 1000 ToolTip第二个ToolTip是用来关闭提示器用的 这个提示功能一定要写…

oracle dba常用脚本2

11、表空间实有、现有、使用情况查询对比 SELECT TABLESPACE_NAME 表空间,TO_CHAR(ROUND(BYTES / 1024, 2), 99990.00) || 实有,TO_CHAR(ROUND(FREE / 1024, 2), 99990.00) || G 现有,TO_CHAR(ROUND((BYTES - FREE) / 1024, 2), 99990.00) || G 使用,TO_CHAR(ROUND(10000 * US…

【开源合规】开源许可证风险场景详细解读

文章目录 前言关于BlackDuck许可证风险对比图弱互惠型许可证举个例子具体示例LGPL系列LGPL-2.0-onlyLGPL-2.0-or-laterLGPL-2.1-onlyLGPL-2.1-or-laterLGPL-3.0-onlyLGPL-3.0-or-laterMPL系列MPL-1.0MPL-1.1MPL-2.0EPL系列EPL-1.0EPL-2.0互惠型许可证GPL系列GPL-1.0GPL-2.0GPL-…

常用录屏软件,分享这四款宝藏软件!

在数字化时代,录屏软件已经成为我们日常工作、学习和娱乐中不可或缺的工具。无论你是需要录制教学视频、游戏过程,还是进行产品演示,一款高效、易用的录屏软件都能让你的工作事半功倍。今天,就为大家揭秘四款宝藏级录屏软件&#…

重磅|九科信息完成诺辉领投的B1轮融资,累计融资已达亿级

近日,九科信息宣布B1轮融资顺利完成。本轮由深圳诺辉岭南投资管理有限公司领投,深创投索斯福(深圳)私募创业投资基金跟投。 截至本轮,九科信息累计融资达亿级。但真正让九科人骄傲的,并非融资本身&#xff…

无法找到模块“@wangeditor/editor-for-vue”的声明文件

vue3项目中使用wangeditor/editor遇到的问题 开发环境不管红线报错正常使用 打包的时候就会报错了 1.安装依赖 pnpm install --save wangeditor/editor wangeditor/editor-for-vuenext 2.遇到的问题 3.解决方法 在src目录下面创建 wangeditor-types.d.ts 文件 代码如下 de…

The First项目报告:创新型金融生态Lista DAO

一、Lista DAO是什么? LISTA是Lista DAO的原生加密协议代币,设计为一种可互操作的实用代币,旨在促进去中心化金融(DeFi)领域内的支付、治理与激励。LISTA的诞生源于Lista DAO项目,该项目是一个基于BNB链的…

springboot3 集成GraalVM

目录 安装GraalVM 配置环境变量 Pom.xml 配置 build包 测试 安装GraalVM Download GraalVM 版本和JDK需要自己选择 配置环境变量 Jave_home 和 path 设置setting.xml <profile><id>graalvm-ce-dev</id><repositories><repository><id&…

2024最新版pycharm安装激火教程,附安装包+激huo马,Python教程,pycharm安装包!!

PyCharm的安装 PyCharm 是一个专门为 Python 开发者设计的 IDE&#xff0c;它同样具有代码导航、重构、调试和分析等功能。PyCharm 支持多种项目类型&#xff0c;如普通项目、Python 测试项目、Django 项目等&#xff0c;并提供了大量的内置模板和插件&#xff0c;以帮助您更快…

elementui实现复杂表单的实践

简介 文章主要讲述在vue3项目中使用elementui框架实现复杂表单的方式。表单中涉及动态组件的生成、文件上传和富文本编辑器的使用&#xff0c;只会将在实现过程中较复杂的部分进行分享&#xff0c;然后提供一份完整的前端代码。 表单效果演示 基础信息 spu属性 sku详情 关键…

融合CDN是什么?为什么需要融合CDN?其应用方法与原理是什么?

你了解融合CDN是什么吗&#xff1f;为什么需要融合CDN&#xff1f;你可能有听过融合CDN&#xff0c;但你知道它的应用方法与原理吗&#xff1f;本文将带你一次了解什么是融合CDN&#xff0c;详细介绍融合CDN的应用方法与运用原理&#xff0c;立刻替您解开心中疑惑&#xff01; …

[微信小程序知识点]自定义组件-拓展-外部样式类

使用组件时&#xff0c;组件使用者可以给组件传入css类名&#xff0c;通过传入的类名修改组件的样式 。 如果需要使用外部样式类修改组件的样式&#xff0c;在Component中需要用extemalClassess定义若干个外部样式类。 具体用法如下: (1)在Components文件里创建custom06组件 (…