pytorch学习——如何构建一个神经网络——以手写数字识别为例

目录

一.概念介绍

1.1神经网络核心组件

1.2神经网络结构示意图

1.3使用pytorch构建神经网络的主要工具

二、实现手写数字识别

2.1环境

2.2主要步骤

2.3神经网络结构

2.4准备数据

2.4.1导入模块

2.4.2定义一些超参数

2.4.3下载数据并对数据进行预处理

2.4.4可视化数据集中部分元素

 2.4.5构建模型和实例化神经网络

2.4.6训练模型

2.4.7可视化损失函数

2.4.7.1 train  loss 

 2.4.7.2 test loss

一.概念介绍

        神经网络是一种计算模型,它模拟了人类神经系统的工作方式,由大量的神经元和它们之间的连接组成。每个神经元接收一些输入信息,并对这些信息进行处理,然后将结果传递给其他神经元。这些神经元之间的连接具有不同的权重,这些权重可以根据神经网络的训练数据进行调整。通过调整权重,神经网络可以对输入数据进行分类、回归、聚类等任务。

        通俗来讲,神经网络就是设置一堆参数,初始化这堆参数,然后通过求导,知道这些参数对结果的影响,然后调整这些参数的大小。直到参数大小可以接近完美地拟合实际结果。神经网络有两个部分:正向传播和反向传播。正向传播是求值,反向传播是求出参数对结果的影响,从而调整参数。所以,神经网络:正向传播->反向传播->正向传播->反向传播……     

        比如我们要预测一个图像是不是猫。如果是猫,它的结果就是1,如果不是猫,它的结果就是0.我们现在有一堆图片,有的是猫,有的不是猫,所以它对应的标签(这个是y)是:0 1 1 0 1。而我们的预测结果可能是对的,也可能是错的,假设我们的预测结果是:0 0 1 1 0.我们有3个预测对了,有2个预测错了。那么我们的损失值是2/5。当然这么搞的话太“粗糙”了,实际上我们会有一个函数来定义损失值是什么。而且我们的预测结果也不是一个确凿的数字,而是一个概率:比如我们预测第3张图片是猫的概率是0.8,那么我们的预测结果是0.8.总之,定义了损失值(这个损失值记为J)以后,我们要让这个损失值尽可能地小。

参考:什么是神经网络? - 绯红之刃的回答 - 知乎 

1.1神经网络核心组件

        神经网络看上去挺复杂,节点多,层多,参数多,但其结构都是类似的,核心部分和组件都是相通的,确定完这些核心组件,这个神经网络也就基本确定了。

核心组件包括:

(1)层:神经网络的基础数据结构是层,层是一个数据处理模块,它接受一个或多个张量作为输入,并输出一个或多个张量,由一组可调整参数描述。

(2)模型:模型是由多个层组成的网络,用于对输入数据进行分类、回归、聚类等任务。

 

(3)损失函数:参数学习的目标函数,通过最小化损失函数来学习各种参数。损失函数是衡量模型输出结果与真实标签之间的差异的函数,目标是最小化损失函数,提高模型性能。

(4)优化器:使损失函数的值最小化。根据损失函数的梯度更新神经网络中的权重和偏置,以使损失函数的值最小化,提高模型性能和稳定性。

1.2神经网络结构示意图

 描述:多个层链接在一起构成一个模型或网络,输入数据通过这个模型转换为预测值,然后损失函数把预测值与真实值进行比较,得到损失值(损失值可以是距离、概率值等),该损失值用于衡量预测值与目标结果的匹配或相似程度,优化器利用损失值更新权重参数,从而使损失值越来越小。这是一个循环过程,损失值达到一个阀值或循环次数到达指定次数,循环结束。

1.3使用pytorch构建神经网络的主要工具

 参考:第3章 Pytorch神经网络工具箱 | Python技术交流与分享

在PyTorch中,构建神经网络主要使用以下工具:

  1. torch.nn模块:提供了构建神经网络所需的各种层和模块,如全连接层、卷积层、池化层、循环神经网络等。

  2. torch.nn.functional模块:提供了一些常用的激活函数和损失函数,如ReLU、Sigmoid、CrossEntropyLoss等。

  3. torch.optim模块:提供了各种优化器,如SGD、Adam、RMSprop等,用于更新神经网络中的权重和偏置。

  4. torch.utils.data模块:提供了处理数据集的工具,如Dataset、DataLoader等,可以方便地处理数据集、进行批量训练等操作。

这些工具之间的相互关系如下:

  1. 使用torch.nn模块构建神经网络的各个层和模块。

  2. 使用torch.nn.functional模块中的激活函数和损失函数对神经网络进行非线性变换和优化。

  3. 使用torch.optim模块中的优化器对神经网络中的权重和偏置进行更新,以最小化损失函数。

  4. 使用torch.utils.data模块中的数据处理工具对数据集进行处理,方便地进行批量训练和数据预处理。

二、实现手写数字识别

2.1环境

        实例环境使用Pytorch1.0+,GPU或CPU,源数据集为MNIST。

2.2主要步骤

(1)利用Pytorch内置函数mnist下载数据
(2)利用torchvision对数据进行预处理,调用torch.utils建立一个数据迭代器
(3)可视化源数据
(4)利用nn工具箱构建神经网络模型
(5)实例化模型,并定义损失函数及优化器
(6)训练模型
(7)可视化结果

2.3神经网络结构

实验中使用两个隐含层,每层激活函数为Relu,最后使用torch.max(out,1)找出张量out最大值对应索引作为预测值。

2.4准备数据

2.4.1导入模块

import numpy as np
import torch
# 导入 pytorch 内置的 mnist 数据
from torchvision.datasets import mnist 
#导入预处理模块
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
#导入nn及优化器
import torch.nn.functional as F
import torch.optim as optim
from torch import nn

2.4.2定义一些超参数

# 定义训练和测试时的批处理大小
train_batch_size = 64
test_batch_size = 128# 定义学习率和迭代次数
learning_rate = 0.01
num_epoches = 20# 定义优化器的超参数
lr = 0.01
momentum = 0.5
#动量优化器通过引入动量参数(Momentum),在更新参数时考虑之前的梯度信息,可以使得参数更新方向更加稳定,同时加速梯度下降的收敛速度。动量参数通常设置在0.5到0.9之间,可以根据具体情况进行调整。

2.4.3下载数据并对数据进行预处理

#定义预处理函数,这些预处理依次放在Compose函数中。
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.5], [0.5])])
#下载数据,并对数据进行预处理
train_dataset = mnist.MNIST('./data', train=True, transform=transform, download=True)
test_dataset = mnist.MNIST('./data', train=False, transform=transform)
#dataloader是一个可迭代对象,可以使用迭代器一样使用。
train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)

注:

①transforms.Compose可以把一些转换函数组合在一起;
②Normalize([0.5], [0.5])对张量进行归一化,这里两个0.5分别表示对张量进行归一化的全局平均值和方差。因图像是灰色的只有一个通道,如果有多个通道,需要有多个数字,如三个通道,应该是Normalize([m1,m2,m3], [n1,n2,n3])
③download参数控制是否需要下载,如果./data目录下已有MNIST,可选择False。
④用DataLoader得到生成器,这可节省内存。

2.4.4可视化数据集中部分元素

# 导入matplotlib.pyplot库,并设置inline模式
import matplotlib.pyplot as plt
%matplotlib inline# 枚举数据加载器中的一批数据
examples = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(examples)# 创建一个图像对象
fig = plt.figure()# 显示前6个图像和对应的标签
for i in range(6):plt.subplot(2,3,i+1)           # 将图像分成2行3列,当前位置为第i+1个plt.tight_layout()             # 自动调整子图之间的间距plt.imshow(example_data[i][0], cmap='gray', interpolation='none')  # 显示图像plt.title("Ground Truth: {}".format(example_targets[i]))          # 显示标签plt.xticks([])                 # 隐藏x轴刻度plt.yticks([])                 # 隐藏y轴刻度

注:

  1. 导入matplotlib.pyplot库,并设置inline模式,以在Jupyter Notebook中显示图像。

  2. 枚举数据加载器中的一批数据,其中test_loader是一个测试数据集加载器。

  3. 创建一个图像对象,用于显示图像和标签。

  4. 显示前6个图像和对应的标签,其中plt.subplot()用于将图像分成2行3列,plt.tight_layout()用于自动调整子图之间的间距,plt.imshow()用于显示图像,plt.title()用于显示标签,plt.xticks()和plt.yticks()用于隐藏x轴和y轴的刻度。

 2.4.5构建模型和实例化神经网络

class Net(nn.Module):"""使用sequential构建网络,Sequential()函数的功能是将网络的层组合到一起"""def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):super(Net, self).__init__()self.layer1 = nn.Sequential(nn.Linear(in_dim, n_hidden_1),nn.BatchNorm1d(n_hidden_1))self.layer2 = nn.Sequential(nn.Linear(n_hidden_1, n_hidden_2),nn.BatchNorm1d(n_hidden_2))self.layer3 = nn.Sequential(nn.Linear(n_hidden_2, out_dim))def forward(self, x):x = F.relu(self.layer1(x))x = F.relu(self.layer2(x))x = self.layer3(x)return x#检测是否有可用的GPU,有则使用,否则使用CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#实例化网络
model = Net(28 * 28, 300, 100, 10)
model.to(device)# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

2.4.6训练模型

# 开始训练
losses = []
acces = []
eval_losses = []
eval_acces = []for epoch in range(num_epoches):train_loss = 0train_acc = 0model.train()#动态修改参数学习率if epoch%5==0:optimizer.param_groups[0]['lr']*=0.1for img, label in train_loader:img=img.to(device)label = label.to(device)img = img.view(img.size(0), -1)# 前向传播out = model(img)loss = criterion(out, label)# 反向传播optimizer.zero_grad()loss.backward()optimizer.step()# 记录误差train_loss += loss.item()# 计算分类的准确率_, pred = out.max(1)num_correct = (pred == label).sum().item()acc = num_correct / img.shape[0]train_acc += acclosses.append(train_loss / len(train_loader))acces.append(train_acc / len(train_loader))# 在测试集上检验效果eval_loss = 0eval_acc = 0# 将模型改为预测模式model.eval()for img, label in test_loader:img=img.to(device)label = label.to(device)img = img.view(img.size(0), -1)out = model(img)loss = criterion(out, label)# 记录误差eval_loss += loss.item()# 记录准确率_, pred = out.max(1)num_correct = (pred == label).sum().item()acc = num_correct / img.shape[0]eval_acc += acceval_losses.append(eval_loss / len(test_loader))eval_acces.append(eval_acc / len(test_loader))print('epoch: {}, Train Loss: {:.4f}, Train Acc: {:.4f}, Test Loss: {:.4f}, Test Acc: {:.4f}'.format(epoch, train_loss / len(train_loader), train_acc / len(train_loader), eval_loss / len(test_loader), eval_acc / len(test_loader)))

2.4.7可视化损失函数

2.4.7.1 train  loss 

plt.title('train loss')
plt.plot(np.arange(len(losses)), losses)
plt.legend(['Train Loss'], loc='upper right')

 2.4.7.2 test loss

# 绘制测试集损失函数
plt.plot(eval_losses, label='Test Loss')
plt.title('Test Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/19766.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

RocketMQ生产者和消费者都开启Message Trace后,Consume Message Trace没有消费轨迹

一、依赖 <dependency><groupId>org.apache.rocketmq</groupId><artifactId>rocketmq-spring-boot-starter</artifactId><version>2.0.3</version> </dependency>二、场景 1、生产者和消费者所属同一个程序 2、生产者开启消…

【css】css实现水平和垂直居中

通过 justify-content 和 align-items设置水平和垂直居中&#xff0c; justify-content 设置水平方向&#xff0c;align-items设置垂直方向。 代码&#xff1a; <style> .center {display: flex;justify-content: center;align-items: center;height: 200px;border: 3px…

【前端入门之旅】HTML中元素和标签有什么区别?

聚沙成塔每天进步一点点 ⭐ 专栏简介⭐ 标签&#xff08;Tag&#xff09;⭐元素&#xff08;Element&#xff09;⭐ 写在最后 ⭐ 专栏简介 前端入门之旅&#xff1a;探索Web开发的奇妙世界 记得点击上方或者右侧链接订阅本专栏哦 几何带你启航前端之旅 欢迎来到前端入门之旅&a…

千“垂”百炼:垂直领域与语言模型

这一系列文章仍然坚持走“通俗理解”的风格&#xff0c;用尽量简短、简单、通俗的话来描述清楚每一件事情。本系列主要关注语言模型在垂直领域尝试的相关工作。 This series of articles still sticks to the "general understanding" style, describing everything…

网络安全--原型链污染

目录 1.什么是原型链污染 2.原型链三属性 1&#xff09;prototype 2)constructor 3)__proto__ 4&#xff09;原型链三属性之间关系 3.JavaScript原型链继承 1&#xff09;分析 2&#xff09;总结 3)运行结果 4.原型链污染简单实验 1&#xff09;实验一 2&#xff0…

微信小程序animation动画,微信小程序animation动画无限循环播放

需求是酱紫的&#xff1a; 页面顶部的喇叭通知&#xff0c;内容不固定&#xff0c;宽度不固定&#xff0c;就是做走马灯&#xff08;轮播&#xff09;效果&#xff0c;从左到右的走马灯&#xff08;轮播&#xff09;&#xff0c;每播放一遍暂停 1500ms &#xff5e; 2000ms 刚…

【ASP.NET MVC】MYSQL安装配置(4)

一、安装配置 1、下载MYSQL绿色版压缩包&#xff08;略&#xff09; 2、解压到目录&#xff0c;比如E:\mysql目录 3、设置环境变量 添加bin目录到path&#xff0c;方便运行Mysql的命令 先打开系统的《环境变量》配置 双击系统变量中的Path 添加Mysql的BIN目录到path: 4、在…

解决一个Yarn异常:Alerts for Timeline service 2.0 Reader

【背景】 环境是用Ambari搭建的大数据环境&#xff0c;版本是2.7.3&#xff0c;Hdp是3.1.0&#xff1b;我们用这一套组件搭建了好几个环境&#xff0c;都有这个异常告警&#xff0c;但hive、spark都运行正常&#xff0c;可以正常使用&#xff0c;所以也一直没有去费时间解决这…

jar命令的安装与使用

场景&#xff1a; 项目中经常遇到使用WinR软件替换jar包中的文件&#xff0c;有时候存在WinRAR解压替换时提示没有权限&#xff0c;此时winRAR不能用还有有什么方法替换jar包中的文件。 方法&#xff1a; 使用jar命令进行修改替换 问题&#xff1a; 执行jar命令报错jar 不…

ubuntu git操作记录设置ssh key

用到的命令&#xff1a; 安装git sudo apt-get install git配置git用户和邮箱 git config --global user.name “用户名” git config --global user.email “邮箱地址”安装ssh sudo apt-get install ssh然后查看安装状态&#xff1a; ps -e | grep sshd4. 查看有无ssh k…

一次web网页设计实践——checkbox单选、复选功能的实现

由于工作内容原因近期做了一个网页&#xff0c;记录下。 需求&#xff1a; 写一个如下的页面&#xff0c;包括checkbox单选&#xff0c;checkbox多选&#xff0c;slect&#xff0c;text等控件 内容&#xff1a; 一、checkbox &#xff08;Wlan 开关&#xff09; 要求&#x…

只需十四步,从零开始掌握Python机器学习

推荐阅读&#xff08;点击标题查看&#xff09; 1、Python 数据挖掘与机器学习实践技术应用 2、R-Meta分析与【文献计量分析、贝叶斯、机器学习等】多技术融合实践与拓展 3、最新基于MATLAB 2023a的机器学习、深度学习 4、【八天】“全面助力AI科研、教学与实践技能”夏令营…

python项目开发案例集锦,python项目案例代码

这篇文章主要介绍了python项目开发案例集锦(全彩版)&#xff0c;具有一定借鉴价值&#xff0c;需要的朋友可以参考下。希望大家阅读完这篇文章后大有收获&#xff0c;下面让小编带着大家一起了解一下。 前言 22个通过Python构建的项目&#xff0c;以此来学习Python编程。 ① 骰…

变透明的黑匣子:UCLA 开发可解释神经网络 SNN 预测山体滑坡

内容一览&#xff1a;由于涉及到多种时空变化因素&#xff0c;山体滑坡预测一直以来都非常困难。深度神经网络 (DNN) 可以提高预测准确性&#xff0c;但其本身并不具备可解释性。本文中&#xff0c;UCLA 研究人员引入了 SNN。SNN 具有完全可解释性、高准确性、高泛化能力和低模…

一元三次方程求解

一元三次方程求解 题目描述提示输入输出格式输入格式输出格式 输入输出样例输入样例输出样例 算法分析A C 代码 题目描述 有形如&#xff1a; a x 3 b x 2 c x d 0 ax^3bx^2c^xd0 ax3bx2cxd0一元三次方程。给出该方程中各项的系数 ( a a a&#xff0c; b b b&#xff0c;…

无限遍历,Python实现在多维嵌套字典、列表、元组的JSON中获取数据

目录 背景 思路 新建两个函数A和B&#xff0c;函数 A处理字典数据&#xff0c;被调用后&#xff0c;判断传递的参数&#xff0c;如果参数为字典&#xff0c;则调用自身&#xff1b; 如果是列表或者元组&#xff0c;则调用列表处理函数B&#xff1b; 函数 B处理列表&#x…

TabR:检索增强能否让深度学习在表格数据上超过梯度增强模型?

这是一篇7月新发布的论文&#xff0c;他提出了使用自然语言处理的检索增强Retrieval Augmented技术&#xff0c;目的是让深度学习在表格数据上超过梯度增强模型。 检索增强一直是NLP中研究的一个方向&#xff0c;但是引入了检索增强的表格深度学习模型在当前实现与非基于检索的…

MySQL的使用——【初识MySQL】第二节

MySQL的使用——【初识MySQL】第二节 文章目录 MySQL环境变量的配置&#xff08;如使用Navicat可忽略&#xff09;使用命令行连接MySQL&#xff08;如使用Navicat可忽略&#xff09;步骤注意 NavicatNavicat的下载Navicat的使用连接MySQL新建表 总结总结 MySQL环境变量的配置&a…

【秋招】算法岗的八股文之机器学习

目录 机器学习特征工程常见的计算模型总览线性回归模型与逻辑回归模型线性回归模型逻辑回归模型区别 朴素贝叶斯分类器模型 (Naive Bayes)决策树模型随机森林模型支持向量机模型 (Support Vector Machine)K近邻模型神经网络模型卷积神经网络&#xff08;CNN&#xff09;循环神经…

【雕爷学编程】MicroPython动手做(28)——物联网之Yeelight 3

知识点&#xff1a;什么是掌控板&#xff1f; 掌控板是一块普及STEAM创客教育、人工智能教育、机器人编程教育的开源智能硬件。它集成ESP-32高性能双核芯片&#xff0c;支持WiFi和蓝牙双模通信&#xff0c;可作为物联网节点&#xff0c;实现物联网应用。同时掌控板上集成了OLED…