通过PyTorch 手写数字识别 入门神经网络 详细讲解

通过PyTorch 手写数字识别 入门神经网络

数据集

在这里插入图片描述

MNIST数据集中有手写数字图片7万张,划分训练集6万张,划分测试集1万张。

每张图片都会有一张标签,也就是代表着图片的真实值(真实含义)。

概念 计算机是如何读取图片的呢?

把照片当作一个数列矩阵给计算机读取,将照片特征从右到左拼接成一列输入到网络层。

在这里插入图片描述

网络层

每一次的节点是由前一层计算得到的,a和b 分别代表系数(权重)和偏置项。

i表示前一层的节点序号,j表示当前节点的序号。

在机器学习和深度学习中,偏置项(Bias)是模型的一个重要组成部分。它是一个可学习的参数,通常用来调整模型的输出,使其能够更好地拟合训练数据。

在这里插入图片描述

最后图像的信息通过网络的传播一直传播到最后一层,k代表的是网络的层数。最后一层就是输出层,而最后有10个节点代表的分别是10个数字的可能性结果,每个节点对应一种可能。
在这里插入图片描述

归一化处理

因为每个节点代表的应该是概率,而每个节点的数值都应该是0<p<1,且每层的总和应该为1。

所以我们需要对节点做归一化处理

Softmax归一化

在多分类问题中,通常会使用softmax函数作为网络输出层的激活函数,softmax函数可以对输出值进行归一化操作,把所有输出值都转化为概率(0~1之间),所有概率值加起来等于1∶
在这里插入图片描述

例如:某个神经网络有3个输出值,为[1,5,3]。

在数学中有个数叫e(数学中一个常数,是一个无限不循环小数,且为超越数,其值约为2.718281828459045)

先计算出e1(e的1次方),e5,e3和它们的和的数值来,e1=2.718、e5=148.413,e3=20.086、e1+e5+e^3=171.217
1的概率:

在这里插入图片描述

3的概率:在这里插入图片描述
5的概率:
在这里插入图片描述

0.016+0.867+0.117=1

在这里插入图片描述

训练

现在我们的输出具有了概率这个概念,那么我们真正要使得我们的概率有意义,那么就需要进行"训练"!

我们一开始的概率分布是随机的,而我们这张图片代表是7,而理想状态下,这张图片是7的概率应该是百分百,而现实训练过程中与理想状态的差值便是损失loss。

所以为了减小损失,我们需要在训练过程中调整网络参数,也就是a和b,使得更接近与理想状态的预测判断概率。

调整网络参数的算法有很多,比如梯度下降算法,ADAM算法等等。从而神经网络问题就变成了一个最优化问题,在多次尝试下寻找到最优解。

在这里插入图片描述

而这仅仅是一张图片,假如我们对于上万张图片进行训练,从而调整得到合适的网络参数,便能使得我们的神经网络具备预测的能力,因为一次只输入一张图片,我们的效率会很低的,所以我们会分批次,几张图片一起输入到网络,这个批次的概念叫做batchSize。

在这里插入图片描述

激活函数

如果没有激活函数,观察我们的节点计算,我们会发现我们节点中的计算都是线性的,但我们生活中很多问题都不是线性的,输入和输出之间存在着非线性,因为一个线性函数无论怎么调整都调整不出非线性函数的效果(模拟出非线性行为),所以我们会在每一次计算中都套一层激活函数,从而达到非线性计算。

在这里插入图片描述

常见的激活函数如下:
在这里插入图片描述

在这个手写数字识别中,我们采用整流函数,因为当x小于0的时候都归0,x大于0的时候才会有数值,也就相当于激活的效果。

项目实现

安装库: pip install numpy torch torchvision matplotlib

pytorch GPU的安装方法更详细可以参考这篇文章:全网最详细的安装pytorch GPU方法,一次安装成功!!包括安装失败后的处理方法!-CSDN博客

首次运行会安装MNIST数据集。

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
import matplotlib.pyplot as pltclass Net(torch.nn.Module):def __init__(self):super().__init__()# 输入为28*28的像素图片,中间三层都放了64个节点self.fc1 = torch.nn.Linear(28 * 28, 64)self.fc2 = torch.nn.Linear(64, 64)self.fc3 = torch.nn.Linear(64, 64)self.fc4 = torch.nn.Linear(64, 10)def forward(self, x):# fc1全连接线性计算,再套上激活函数relux = torch.nn.functional.relu(self.fc1(x))x = torch.nn.functional.relu(self.fc2(x))x = torch.nn.functional.relu(self.fc3(x))# log_softmax softmax归一化再套上log让计算更稳定x = torch.nn.functional.log_softmax(self.fc4(x), dim=1)return xdef get_data_loader(is_train):# 进行数据转换,tensor就是一个多维数组(又叫张量) 定义数据转换类型to_tensor = transforms.Compose([transforms.ToTensor()])# 下载MNIST数据集,""代表当前目录,is_train用于指定是导入训练集还是测试集data_set = MNIST("", is_train, transform=to_tensor, download=True)# batch_size 表示一个批次包含15张图片 ,shuffle表示数据是否是随机打乱的return DataLoader(data_set, batch_size=15, shuffle=True)# 用于评估神经网络的正确率
def evaluate(test_data, net):n_correct = 0n_total = 0with torch.no_grad():# 在测试集中按批次取出数据for (x, y) in test_data:# 计算神经网络的预测值 x代表图片, y代表真实结果(标签)outputs = net.forward(x.view(-1, 28 * 28))# 再与真实结果进行比较进行累加记录for i, output in enumerate(outputs):# argmax是找到一个数列中最大值的序号,也就是预测结果if torch.argmax(output) == y[i]:n_correct += 1n_total += 1return n_correct / n_totaldef main():# 导入训练集和测试集train_data = get_data_loader(is_train=True)test_data = get_data_loader(is_train=False)# 初始化神经网络net = Net()# 打印初始网络的正确率 一般是0.1,因为10种结果,猜对的概率是十分之一print("initial accuracy:", evaluate(test_data, net))optimizer = torch.optim.Adam(net.parameters(), lr=0.001)# epoch是训练轮次for epoch in range(2):# 这部分基本是通用写法for (x, y) in train_data:# 初始化net.zero_grad()# 正向传播output = net.forward(x.view(-1, 28 * 28))# 计算差值 nll_loss 是一个对数损失函数 是为了匹配前面的log_softmax的对数运算loss = torch.nn.functional.nll_loss(output, y)# 反向误差传播loss.backward()# 优化网络参数optimizer.step()# 每个epoch结束后打印一次正确率print("epoch", epoch, "accuracy:", evaluate(test_data, net))# 随机抽取三张图片验证模型性能for (n, (x, _)) in enumerate(test_data):if n > 3:breakpredict = torch.argmax(net.forward(x[0].view(-1, 28 * 28)))plt.figure(n)plt.imshow(x[0].view(28, 28))plt.title("prediction: " + str(int(predict)))plt.show()if __name__ == "__main__":main()

正向传播传播算出当前的概率或值,反向传播将计算得到的损失告诉网络层,从而进行优化调整。

epoch代表的是训练轮次,比如六万张图片,训练了两次六万张图片,那么就是训练两个epoch。

最终结果:在这里插入图片描述

训练一次的时候正确率就从0.09达到了0.95,第二次就到了0.96,提升就相对少了。

概念和语法问题

上下文管理器

with torch.no_grad(): 是一个上下文管理器,它确保在其控制下的代码块内不会执行梯度计算。在 PyTorch 中,当我们构建计算图时,默认情况下会对每个操作进行跟踪,以便能够计算梯度。这对于训练模型是必要的,因为我们需要通过反向传播来更新权重。然而,在模型的评估阶段或者当我们只需要前向传递来得到输出而不需要更新模型参数时,保持梯度计算是不必要的,甚至会消耗额外的内存和计算资源。

使用 torch.no_grad() 的好处包括:

  • 节省内存:不需要存储中间变量的梯度信息。
  • 提高性能:省去了梯度计算的时间。

当你看到 with torch.no_grad():,这意味着在这段代码执行期间,所有涉及到自动梯度计算的操作都将被忽略,即创建的 Tensor 不会被加入到计算图中。这对于评估模型、生成模型输出以及任何不需要梯度计算的任务都是有用的。

示例代码:

with torch.no_grad():# 在这里创建的所有 Tensor 和执行的所有操作都不会被记录在计算图中predictions = model(inputs)

在你的代码中,torch.no_grad() 被用在 evaluate 函数中,以确保在评估模型的准确率时不会进行不必要的梯度计算,从而提高效率并节约内存。这是因为评估阶段我们关心的是模型的性能而非更新模型参数。

全连接层(Fully Connected Layer,简称 FC 层)是神经网络中最基本的组件之一,也是最直观的一种层。在一个全连接层中,前一层的所有神经元(节点)都与后一层的所有神经元相连,也就是说,每一层的每一个神经元都会接受前一层所有神经元的输出作为输入。

全连接层的工作原理

在一个全连接层中,每个神经元的输出是由前一层所有神经元的输出经过加权求和后加上偏置项(bias),然后通过激活函数计算得出的。数学上,可以用以下公式来表示:

z=W⋅x+bz=Wx+b

其中:

  • zz 是神经元的加权输入(未经过激活函数之前的值)。
  • WW 是权重矩阵。
  • xx 是输入向量。
  • bb 是偏置项。

接着,zz 会通过一个激活函数(如 ReLU、Sigmoid、tanh 或者其他激活函数)来产生非线性映射:

h=f(z)h=f(z)

这里的 h 是最终的输出,f 是激活函数。

应用场景

全连接层通常用于处理一维的数据,例如从卷积层提取的特征向量或者是展平后的图像数据。在图像识别、自然语言处理等领域,全连接层常用于提取特征之后的分类任务。例如,在你的代码中,全连接层用于将输入图像(展平后的28x28像素,共784个元素)映射到一个较低维度的空间,最后输出类别概率分布。

优化器

optimizer = torch.optim.Adam(net.parameters(), lr=0.001) 这一行代码的作用是在 PyTorch 中创建一个优化器实例,用于更新神经网络的参数。让我们分解一下这条语句的含义:

代码解析

  1. torch.optim.Adam:这是 PyTorch 提供的一种优化算法——Adam(Adaptive Moment Estimation)算法的实现。Adam 是一种自适应学习率优化方法,它结合了动量(Momentum)和 RMSProp 的优点,能够在训练过程中动态调整每个参数的学习率。
  2. net.parameters():这是神经网络模型 net 的所有可学习参数的迭代器。这些参数通常是模型中的权重和偏置项,它们是训练过程中需要更新的对象。
  3. lr=0.001:这是学习率(Learning Rate)的设定值。学习率决定了参数更新的步长大小。较高的学习率会使参数更新更快,但也可能导致训练过程不稳定;较低的学习率则会使训练过程更稳定,但可能需要更多的时间来收敛。

作用

这句话的主要作用是创建一个 Adam 优化器,并指定要优化的参数集合以及学习率为 0.001。这个优化器将在训练过程中使用,具体来说:

  • 初始化优化器:创建一个 Adam 优化器实例,准备好对模型的参数进行优化。
  • 参数绑定:将模型的所有可学习参数传递给优化器,以便在训练过程中更新这些参数。
  • 设置学习率:确定了优化过程中每次参数更新的步长大小。

参考:

10分钟入门神经网络 PyTorch 手写数字识别_哔哩哔哩_bilibili

pytorch tutorial: PyTorch 手写数字识别 教程代码 (gitee.com)

秒懂Softmax归一化_矩阵列向量softmax归一化计算-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/54940.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

多态常见面试问题

1、什么是多态&#xff1f; 多态&#xff08;Polymorphism&#xff09;是面向对象编程中的一个重要概念&#xff0c;它允许同一个接口表现出不同的行为。在C中&#xff0c;多态性主要通过虚函数来实现&#xff0c;分为编译时多态&#xff08;静态多态&#xff09;和运行时多态…

【Spring AI】Java实现类似langchain的第三方函数调用_原理与详细示例

Spring AI 介绍 &#xff1a;简化Java AI开发的统一接口解决方案 在过去&#xff0c;使用Java开发AI应用时面临的主要困境是没有统一且标准的封装库&#xff0c;导致开发者需要针对不同的AI服务提供商分别学习和对接各自的API&#xff0c;这增加了开发难度与迁移成本。而Sprin…

【数据结构】邻接表

一、概念 邻接表是一个顺序存储与链式存储相结合的数据结构&#xff0c;用于描述一个图中所有节点之间的关系。 若是一个稠密图&#xff0c;我们可以选择使用邻接矩阵&#xff1b;但当图较稀疏时&#xff0c;邻接矩阵就显得比较浪费空间了&#xff0c;此时我们就可以换成邻接…

机器人的应用 基于5G的变电站智慧管控系统

背景概述 一、电力行业面临的挑战与变革 随着全球工业化和信息化的快速发展&#xff0c;电力行业作为国民经济的基础性行业&#xff0c;其重要性日益凸显。然而&#xff0c;随着电力网络的不断扩展和复杂化&#xff0c;变电站和开关站作为电力传输与分配的关键节点&#xff0…

Excel中Ctrl+e的用法

重点&#xff1a;想要使用ctrle&#xff0c;前提是整合或拆分后的结果放置的单元格必须和被提取信息的单元格相邻&#xff0c;且被提取信息的单元格也必须相连。 下图为错误示例 这样则可以使用ctrle 1、信息整合 2、提取信息 3、添加符号 4、信息顺序调换 5、数字提取 crtle还…

HarmonyOS NEXT 应用开发实战(三、ArkUI页面底部导航TabBar的实现)

在开发HarmonyOS NEXT应用时&#xff0c;TabBar是用户界面设计中不可或缺的一部分。本文将通过代码示例&#xff0c;带领大家一同实现一个常用的TabBar&#xff0c;涵盖三个主要的内容页&#xff1a;首页、知乎日报和我的页面。以模仿知乎日报的项目为背景驱动&#xff0c;设定…

解决ubuntu 下 VS code 无法打开点击没反应问题

从Ubuntu 22.04 升级到ubuntu 24.04 后&#xff0c;发现Vsode无法打开&#xff0c;不论是点击图标&#xff0c;还是terminator里面运行code 可执行程序&#xff0c;均没有反应。debug如下: 提示权限不够。 解决方案&#xff1a; sudo sysctl -w kernel.apparmor_restrict_unp…

C语言题目练习2

前面我们知道了单链表的结构及其一些数据操作&#xff0c;今天我们来看看有关于单链表的题目~ 移除链表元素 移除链表元素&#xff1a; https://leetcode.cn/problems/remove-linked-list-elements/description/ 这个题目要求我们删除链表中是指定数据的结点&#xff0c;最终返…

C语言 | Leetcode C语言题解之第460题LFU缓存

题目&#xff1a; 题解&#xff1a; /* 数值链表的节点定义。 */ typedef struct ValueListNode_s {int key;int value;int counter;struct ValueListNode_s *prev;struct ValueListNode_s *next; } ValueListNode;/* 计数链表的节点定义。 其中&#xff0c;head是数值链表的头…

腾讯云Android 与 iOS 相关

移动端&#xff08;Android/iOS&#xff09;支持哪几种系统音量模式&#xff1f; 支持2种系统音量类型&#xff0c;即通话音量类型和媒体音量类型&#xff1a; 通话音量&#xff1a;手机专门为通话场景设计的音量类型&#xff0c;使用手机自带的回声抵消功能&#xff0c;音质…

谷歌浏览器 文件下载提示网络错误

情况描述&#xff1a; 谷歌版本&#xff1a;129.0.6668.90 (正式版本) &#xff08;64 位&#xff09; (cohort: Control)其他浏览器&#xff0c;比如火狐没有问题&#xff0c;但是谷歌会下载失败&#xff0c;故推断为谷歌浏览器导致的问题小文件比如1、2M会成功&#xff0c;大…

【LeetCode】动态规划—95. 不同的二叉搜索树 II(附完整Python/C++代码)

动态规划—95. 不同的二叉搜索树 II 题目描述前言基本思路1. 问题定义二叉搜索树的性质&#xff1a; 2. 理解问题和递推关系递归构造思想&#xff1a;状态定义&#xff1a;递推公式&#xff1a;终止条件&#xff1a; 3. 解决方法递归 动态规划方法&#xff1a;伪代码&#xff…

如何使用vscode的launch.json来debug调试

1、创建一个launch.json文件 选择Python Debugger&#xff0c;再选择Python文件&#xff0c;创建处理如下 默认有下面五个参数 "name": "Python Debugger: Current File","type": "debugpy","request": "launch"…

金九银十软件测试面试题(800道)

今年你的目标是拿下大厂offer&#xff1f;还是多少万年薪&#xff1f;其实这些都离不开日积月累的过程。 为此我特意整理出一份&#xff08;超详细笔记/面试题&#xff09;它几乎涵盖了所有的测试开发技术栈&#xff0c;非常珍贵&#xff0c;人手一份 肝完进大厂 妥妥的&#…

【LeetCode】动态规划—123. 买卖股票的最佳时机 III(附完整Python/C++代码)

动态规划—123. 买卖股票的最佳时机 III 题目描述前言基本思路1. 问题定义2. 理解问题和递推关系状态定义&#xff1a;状态转移公式&#xff1a;初始条件&#xff1a; 3. 解决方法动态规划方法伪代码&#xff1a; 4. 进一步优化5. 小总结 Python代码Python代码解释 C代码C代码解…

Python基础之List列表用法

1、创建列表 names ["张三","李四","王五","Mary"] 2、列表分片 names[1]&#xff1a;获取数组的第2个元素。 names[1:3]&#xff1a;获取数组的第2、第3个元素。包含左侧&#xff0c;不包含右侧。 names[:3]等同于names[0:3]&…

List子接口

1.特点&#xff1a;有序&#xff0c;有下标&#xff0c;元素可以重复 2.方法&#xff1a;包含Collection中的所有方法&#xff0c;还包括自己的独有的方法&#xff08;API中查找&#xff09; 还有ListIterator&#xff08;迭代器&#xff09;&#xff0c;功能更强大。 包含更多…

机器学习/数据分析--用通俗语言讲解时间序列自回归(AR)模型,并用其预测天气,拟合度98%+

时间序列在回归预测的领域的重要性&#xff0c;不言而喻&#xff0c;在数学建模中使用及其频繁&#xff0c;但是你真的了解ARIMA、AR、MA么&#xff1f;ACF图你会看么&#xff1f;&#xff1f; 时间序列数据如何构造&#xff1f;&#xff1f;&#xff1f;&#xff0c;我打过不少…

读书笔记 - 虚拟化技术 - 0 QEMU/KVM概述与历史

《QEMU/KVM源码解析与应用》 - 王强 概述 虚拟化简介 虚拟化思想 David Wheeler&#xff1a;计算机科学中任何问题都可以通过增加一个中间层来解决。 虚拟化思想存在与计算机科学的各个领域。 主要思想&#xff1a;通过分层将底层的复杂&#xff0c;难用的资源虚拟抽象为简…

Spring Cloud 3.x 集成eureka快速入门Demo

1.什么是eureka&#xff1f; Eureka 由 Netflix 开发&#xff0c;是一种基于REST&#xff08;Representational State Transfer&#xff09;的服务&#xff0c;用于定位服务&#xff08;服务注册与发现&#xff09;&#xff0c;以实现中间层服务的负载均衡和故障转移&#xff…