计算机视觉——飞桨深度学习实战-深度学习网络模型

深度学习网络模型的整体架构主要数据集、模型组网以及学习优化过程三部分,本章主要围绕着深度学习网络模型的算法架构、常见模型展开了详细介绍,从经典的深度学习网络模型以CNN、RNN为代表,到为了解决显存不足、实时性不够等问题的轻量化网络设计,以及近年来卷各大计算机视觉任务的前沿网络模型Transformer和MLP。为了进一步剖析深度学习网络模型搭建的过程,最后以LeNet模型算法为例,在飞桨深度学习框架下进行了网络搭建案例展示。学完本章,希望读者能够掌握以下知识点:

  1. 了解经典的网络模型(CNN和RNN);
  2. 熟悉前沿的网络模型(Transformer和MLP);
  3. 掌握使用飞桨搭建深度学习网络模型-LeNet。

在前面的学习中,我们应该大概已经了解了国内外计算机视觉的近况和历史,还有深度学习算法基础,对于深度学习的框架应该也有了一个大概的了解,上面说的这些东西自己下去了解一下就好啦,也不会太难。这篇文章主要是针对深度学习的网络模型进行阐述。在了解经典网络模型的同时,也能了解前沿的网络模型,并且会将通过一个简单的例子让大家对于深度学习网络有一个大概的印象。

在深度学习开发框架的支持下,深度学习网络模型不断更新迭代,模型架构从经典的卷积神经网络CNN,循环神经网络RNN。发展到如今的Transformer,多层感知机MLP,而他们可以统一视为通过网络部件激活函数设定,优化策略等一系列操作来搭建深度学习网络模型,并且采用非线性复杂映射将原始数据转变为更高层次等抽象的表达。

深度学习网络模型的整体架构主要包含三部分,分别为数据集,模型组网以及学习优化过程。深度学习网络模型的训练过程即为优化过程,模型优化最直接的目的是通过多次迭代更新来寻找使得损失函数尽可能小的最优模型参数。通常神经网络的优化过程可以分为两个阶段,第一阶段是通过正向传播得到模型的预测值,并将预测值与正值标签进行比对,计算两者之间的差异作为损失值;第二个阶段是通过反向传播来计算损失函数对每个参数的梯度,根据预设的学习率和动量来更新每个参数的值。

总之,一个好的网络模型通常具有以下特点,一,模型易于训练及训练步骤简单且容易收敛,二,模型精度高及能够很好的把握数据的内在本质,可以提取到有用的关键特征,三,模型泛化能力强及模型不仅在已知数据上表现良好,还而且还能够在于已知数据分布一致的未知数据集上表现其鲁棒性。

案例:

一、任务介绍

手写数字识别(handwritten numeral recognition)是光学字符识别技术(optical character recognition,OCR)的一个分支,是初级研究者的入门基础,在现实产业中也占据着十分重要的地位,它主要的研究内容是如何利用电子计算机和图像分类技术自动识别人手写在纸张上的阿拉伯数字(0~9)。因此,本实验任务简易描述如图所示: 

二、模型原理

近年来,神经网络模型一直层出不穷,在各个计算机视觉任务中都呈现百花齐放的态势。为了让开发者更清楚地了解网络模型的搭建过程,以及为了在后续的各项视觉子任务实战中奠定基础。下面本节将以MNIST手写数字识别为例,在PaddlePaddle深度学习开发平台下构建一个LeNet网络模型并进行详细说明。LeNet是第一个将卷积神经网络推上计算机视觉舞台的算法模型,它由LeCun在1998年提出。在早期应用于手写数字图像识别任务。该模型采用顺序结构,主要包括7层(2个卷积层、2个池化层和3个全连接层),卷积层和池化层交替排列。以mnist手写数字分类为例构建一个LeNet-5模型。每个手写数字图片样本的宽与高均为28像素,样本标签值是0~9,代表0至9十个数字。

下面详细解析LeNet-5模型的网络结构及原理

图1 LeNet-5整体网络模型

(1)卷积层L1

L1层的输入数据形状大小为��×1×28×28Rm×1×28×28,表示样本批量为m,通道数量为1,行与列的大小都为28。L1层的输出数据形状大小为��×6×24×24Rm×6×24×24,表示样本批量为m,通道数量为6,行与列维都为24。

这里有两个问题很关键:一是,为什么通道数从1变成了6呢?原因是模型的卷积层L1设定了6个卷积核,每个卷积核都与输入数据发生运算,最终分别得到6组数据。二是,为什么行列大小从28变成了24呢?原因是每个卷积核的行维与列维都为5,卷积核(5×5)在输入数据(28×28)上移动,且每次移动步长为1,那么输出数据的行列大小分别为28-5+1=24。

(2)池化层L2

L2层的输入数据形状大小为��×6×24×24Rm×6×24×24,表示样本批量为m,通道数量为6,行与列的大小都为24。L2层的输出数据形状大小为��×6×12×12Rm×6×12×12,表示样本批量为m,通道数量为6,行与列维都为12。

在这里,为什么行列大小从24变成了12呢?原因是池化层中的过滤器形状大小为2×2,其在输入数据(24×24)上移动,且每次移动步长(跨距)为2,每次选择4个数(2×2)中最大值作为输出,那么输出数据的行列大小分别为24÷2=12。

(3)卷积层L3

L3层的输入数据形状大小为��×6×12×12Rm×6×12×12,表示样本批量为m,通道数量为6,行与列的大小都为12。L3层的输出数据形状大小为��×16×8×8Rm×16×8×8,表示样本批量为m,通道数量为16,行与列维都为8。

(4)池化层L4

L4层的输入数据形状大小为��×16×8×8Rm×16×8×8,表示样本批量为m,通道数量为16,行与列的大小都为8。L4层的输出数据形状大小为��×16×4×4Rm×16×4×4,表示样本批量为m,通道数量为16,行与列维都为4。池化层L4中的过滤器形状大小为2×2,其在输入数据(形状大小24×24)上移动,且每次移动步长(跨距)为2,每次选择4个数(形状大小2×2)中最大值作为输出。

(5)线性层L5

L5层输入数据形状大小为��×256Rm×256,表示样本批量为m,输入特征数量为256。输出数据形状大小为��×120Rm×120,表示样本批量为m,输出特征数量为120。

(6)线性层L6

L6层的输入数据形状大小为��×120Rm×120,表示样本批量为m,输入特征数量为120。L6层的输出数据形状大小为��×84Rm×84,表示样本批量为m,输出特征数量为84。

(7)线性层L7

L7层的输入数据形状大小为��×84Rm×84,表示样本批量为m,输入特征数量为84。L7层的输出数据形状大小为��×10Rm×10,表示样本批量为m,输出特征数量为10。

三、MNIST数据集

3.1 数据集介绍

手写数字分类数据集来源MNIST数据集,该数据集可以公开免费获取。该数据集中的训练集样本数量为60000个,测试集样本数量为10000个。每个样本均是由28×28像素组成的矩阵,每个像素点的值是标量,取值范围在0至255之间,可以认为该数据集的颜色通道数为1。数据分为图片和标签,图片是28*28的像素矩阵,标签为0~9共10个数字。

3.2 数据读取

(1)transform函数是对数据进行归一化和标准化

(2)train_dataset和test_dataset

paddle.vision.datasets.MNIST()中的mode='train'和mode='test'分别用于获取mnist训练集和测试集

#导入数据集Compose的作用是将用于数据集预处理的接口以列表的方式进行组合。
#导入数据集Normalize的作用是图像归一化处理,支持两种方式: 1. 用统一的均值和标准差值对图像的每个通道进行归一化处理; 2. 对每个通道指定不同的均值和标准差值进行归一化处理。
import paddle
from paddle.vision.transforms import Compose, Normalize
import os
import matplotlib.pyplot as plt
transform = Compose([Normalize(mean=[127.5],std=[127.5],data_format='CHW')])
# 使用transform对数据集做归一化
print('下载并加载训练数据')
train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)
val_dataset = paddle.vision.datasets.MNIST(mode='test', transform=transform)
print('加载完成')

让我们一起看看数据集中的图片是什么样子的

train_data0, train_label_0 = train_dataset[0][0],train_dataset[0][1]
train_data0 = train_data0.reshape([28,28])
plt.figure(figsize=(2,2))
print(plt.imshow(train_data0, cmap=plt.cm.binary))
print('train_data0 的标签为: ' + str(train_label_0))
AxesImage(18,18;111.6x108.72)
train_data0 的标签为: [5]
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/cbook/__init__.py:2349: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop workingif isinstance(obj, collections.Iterator):
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/cbook/__init__.py:2366: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop workingreturn list(data) if isinstance(data, collections.MappingView) else data
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/numpy/lib/type_check.py:546: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() instead'a.item() instead', DeprecationWarning, stacklevel=1)

<Figure size 144x144 with 1 Axes>

让我们再来看看数据样子是什么样的吧

print(train_data0)

四、LeNet模型搭建

构建LeNet-5模型进行MNIST手写数字分类

#导入需要的包
import paddle
import paddle.nn.functional as F
from paddle.vision.transforms import Compose, Normalize#定义模型
class LeNetModel(paddle.nn.Layer):def __init__(self):super(LeNetModel, self).__init__()# 创建卷积和池化层块,每个卷积层后面接着2x2的池化层#卷积层L1self.conv1 = paddle.nn.Conv2D(in_channels=1,out_channels=6,kernel_size=5,stride=1)#池化层L2self.pool1 = paddle.nn.MaxPool2D(kernel_size=2,stride=2)#卷积层L3self.conv2 = paddle.nn.Conv2D(in_channels=6,out_channels=16,kernel_size=5,stride=1)#池化层L4self.pool2 = paddle.nn.MaxPool2D(kernel_size=2,stride=2)#线性层L5self.fc1=paddle.nn.Linear(256,120)#线性层L6self.fc2=paddle.nn.Linear(120,84)#线性层L7self.fc3=paddle.nn.Linear(84,10)#正向传播过程def forward(self, x):x = self.conv1(x)x = F.sigmoid(x)x = self.pool1(x)x = self.conv2(x)x = F.sigmoid(x)x = self.pool2(x)x = paddle.flatten(x, start_axis=1,stop_axis=-1)x = self.fc1(x)x = F.sigmoid(x)x = self.fc2(x)x = F.sigmoid(x)out = self.fc3(x)return outmodel=paddle.Model(LeNetModel())

五、模型优化过程

5.1 损失函数

由于是分类问题,我们选择交叉熵损失函数。交叉熵主要用于衡量估计值与真实值之间的差距。交叉熵值越小,模型预测效果越好。*

�(��,�^�)=−∑�=1������(�^��)E(yi,y^​i)=−∑j=1q​yji​ln(y^​ji​)

其中,��∈��yi∈Rq为真实值,���yji​是��yi中的元素(取值为0或1),�=1,...,�j=1,...,q。�^�∈��y^​i∈Rq是预测值(样本在每个类别上的概率)。其中,在paddle里面交叉熵损失对应的API是paddle.nn.CrossEntropyLoss()

5.2 参数优化

定义好了正向传播过程之后,接着随机化初始参数,然后便可以计算出每层的结果,每次将得到m×10的矩阵作为预测结果,其中m是小批量样本数。接下来进行反向传播过程,预测结果与真实结果之间肯定存在差异,以缩减该差异作为目标,计算模型参数梯度。进行多轮迭代,便可以优化模型,使得预测结果与真实结果之间更加接近。

六、模型训练与评估

训练配置:设定训练超参数

1、批大小batch_size设置为64,表示每次输入64张图片;

2、迭代次数epoch设置为5,表示训练5轮;

3、日志显示verbose=1,表示带进度条的输出日志信息。

model.prepare(paddle.optimizer.Adam(parameters=model.parameters()),paddle.nn.CrossEntropyLoss(),paddle.metric.Accuracy())model.fit(train_dataset,epochs=5,batch_size=64,verbose=1)model.evaluate(val_dataset,verbose=1)
The loss value printed in the log is the current step, and the metric is the average value of previous step.
Epoch 1/5
step  10/938 [..............................] - loss: 2.3076 - acc: 0.1062 - ETA: 21s - 23ms/step
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:77: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop workingreturn (isinstance(seq, collections.Sequence) and
step  20/938 [..............................] - loss: 2.3023 - acc: 0.1023 - ETA: 18s - 20ms/step
step 938/938 [==============================] - loss: 0.1927 - acc: 0.7765 - 16ms/step         
Epoch 2/5
step 938/938 [==============================] - loss: 0.0913 - acc: 0.9584 - 17ms/step        
Epoch 3/5
step 938/938 [==============================] - loss: 0.0232 - acc: 0.9700 - 17ms/step         
Epoch 4/5
step 938/938 [==============================] - loss: 0.0057 - acc: 0.9763 - 18ms/step        
Epoch 5/5
step 938/938 [==============================] - loss: 0.0907 - acc: 0.9798 - 17ms/step         
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 10000/10000 [==============================] - loss: 7.5607e-04 - acc: 0.9794 - 2ms/step         
Eval samples: 10000
{'loss': [0.00075607264], 'acc': 0.9794}

经过5个epoch世代迭代,LeNet5模型在MNIST图像分类任务上的准确度达到98%左右。

七、模型可视化

model.summary((1,1,28,28))
---------------------------------------------------------------------------Layer (type)       Input Shape          Output Shape         Param #    
===========================================================================Conv2D-1       [[1, 1, 28, 28]]      [1, 6, 24, 24]          156      MaxPool2D-1     [[1, 6, 24, 24]]      [1, 6, 12, 12]           0       Conv2D-2       [[1, 6, 12, 12]]      [1, 16, 8, 8]          2,416     MaxPool2D-2     [[1, 16, 8, 8]]       [1, 16, 4, 4]            0       Linear-1          [[1, 256]]            [1, 120]           30,840     Linear-2          [[1, 120]]            [1, 84]            10,164     Linear-3          [[1, 84]]             [1, 10]              850      
===========================================================================
Total params: 44,426
Trainable params: 44,426
Non-trainable params: 0
---------------------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.04
Params size (MB): 0.17
Estimated Total Size (MB): 0.22
---------------------------------------------------------------------------
{'total_params': 44426, 'trainable_params': 44426}

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/96374.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

[Linux]线程互斥

[Linux]线程互斥 文章目录 [Linux]线程互斥线程并发访问问题线程互斥控制--加锁pthread_mutex_init函数pthread_mutex_destroy函数pthread_mutex_lock函数pthread_mutex_unlock函数锁相关函数使用示例使用锁的细节加锁解锁的实现原理 线程安全概念常见的线程不安全的情况常见的…

岩土工程监测中无线振弦采集仪的高精度高稳定性的重要性

岩土工程监测中无线振弦采集仪的高精度高稳定性的重要性 岩土工程中&#xff0c;无线振弦采集仪是一种用于测量结构物振动情况的关键设备。该设备主要是为了监测结构物的破坏情况、安全性能、实时振动等相关参数的变化&#xff0c;以便于及时掌握结构物的变化情况&#xff0c;…

【数据结构】手撕归并排序(含非递归)

目录 一&#xff0c;归并排序&#xff08;递归&#xff09; 1&#xff0c;基本思想 2&#xff0c;思路实现 二&#xff0c;归并排序&#xff08;非递归&#xff09; 1&#xff0c;思路实现 2&#xff0c;归并排序的特性总结&#xff1a; 一&#xff0c;归并排序&#xff0…

面试题:在大型分布式系统中,给你一条 SQL,让你优化,你会怎么做?

亲爱的小伙伴们&#xff0c;大家好呀&#xff01;我是小米&#xff0c;一个热爱技术、乐于分享的90后程序猿。今天&#xff0c;我要和大家聊聊一个在大型分布式系统中非常有趣和挑战性的话题——如何优化 SQL 查询&#xff01; 这个问题可不简单&#xff0c;但不要担心&#x…

力扣第100题 相同的数 c++ 二叉 简单易懂+注释

题目 100. 相同的树 简单 给你两棵二叉树的根节点 p 和 q &#xff0c;编写一个函数来检验这两棵树是否相同。 如果两个树在结构上相同&#xff0c;并且节点具有相同的值&#xff0c;则认为它们是相同的。 示例 1&#xff1a; 输入&#xff1a;p [1,2,3], q [1,2,3] 输出…

除静电离子风嘴的工作原理及应用

除静电离子风嘴是一种常见的除静电设备&#xff0c;它的工作原理是通过产生大量的负离子来中和物体表面的静电电荷&#xff0c;从而达到除静电的目的。 除静电离子风嘴内部装有一个电离器&#xff0c;电离器会将空气中的氧气分子或水分子电离成正、负离子。这些带电的离子在空…

工信部教考中心:什么是《研发效能(DevOps)工程师》认证,拿到证书之后有什么作用!(上篇)丨IDCF

在计算机行业中&#xff0c;资质认证可以证明在该领域内的专业能力和知识水平。各种技术水平认证也是层出不穷&#xff0c;而考取具有公信力和权威性的认证是从业者的首选。同时&#xff0c;随着国内企业技术实力的提升和国家对于自主可控的重视程度不断提高&#xff0c;国产证…

铭控传感亮相2023国际物联网展,聚焦“多场景物联感知方案”应用

金秋九月&#xff0c;聚焦IoT基石技术&#xff0c;荟萃最全物联感知企业&#xff0c;齐聚IOTE 2023第20届国际物联网展深圳站。铭控传感携智慧楼宇&#xff0c;数字工厂&#xff0c;智慧消防&#xff0c;智慧泵房等多场景物联感知方案及多品类无线传感器闪亮登场&#xff0c;现…

做外贸独立站选Shopify还是WordPress?

现在确实会有很多新人想做独立站&#xff0c;毕竟跨境电商平台内卷严重&#xff0c;平台规则限制不断升级&#xff0c;脱离平台“绑架”布局独立站&#xff0c;才能获得更多流量、订单、塑造品牌价值。然而&#xff0c;在选择建立外贸独立站的过程中&#xff0c;选择适合的建站…

90、Redis 的 value 所支持的数据类型(String、List、Set、Zset、Hash)---->Hash 相关命令

本次讲解要点&#xff1a; Hash 相关命令&#xff1a;是指value中的数据类型 启动redis服务器&#xff1a; 打开小黑窗&#xff1a; C:\Users\JH>e: E:>cd E:\install\Redis6.0\Redis-x64-6.0.14\bin E:\install\Redis6.0\Redis-x64-6.0.14\bin>redis-server.exe red…

探索JavaScript事件流:DOM中的神奇旅程

&#x1f3ac; 江城开朗的豌豆&#xff1a;个人主页 &#x1f525; 个人专栏 :《 VUE 》 《 javaScript 》 ⛺️ 生活的理想&#xff0c;就是为了理想的生活 ! 目录 引言 1. 事件流的发展流程 1.1 传统的DOM0级事件 1.2 DOM2级事件和addEventListener方法 1.3 W3C DOM3级…

黑马mysql教程笔记(mysql8教程)基础篇——数据库相关概念、mysql安装及卸载、数据模型、SQL通用语法及分类(DDL、DML、DQL、DCL)

参考文章1&#xff1a;https://www.bilibili.com/video/BV1Kr4y1i7ru/ 参考文章2&#xff1a;https://dhc.pythonanywhere.com/article/public/1/ 文章目录 基础篇数据库相关概念&#xff08;数据库DataBase&#xff08;DB&#xff09;、数据库管理系统DataBase Management Sy…

解决Ubuntu18.04安装好搜狗输入法后无法打出中文的问题

首先下载安装 搜狗拼音输入法 &#xff0c;下载选择&#xff1a; x86_64 在ubuntu中设置 fcitx 最后发现安装好了&#xff0c;图标有了 &#xff0c;但是使用时不能输入中文&#xff0c;使用下面的命令解决&#xff1a; sudo apt install libqt5qml5 libqt5quick5 libqt5qu…

学习笔记|串口通信的基础知识|同步/异步|常见的串口软件的参数|STC32G单片机视频开发教程(冲哥)|第二十集:串口通信基础

目录 1.串口通信的基础知识串口通信(Serial Communication)同步/异步&#xff1f;全双工&#xff1f;常见的串口软件的参数 2.STC32的串口通信实现原理引脚选择模式选择 3.串口通信代码实现编写串口1通信程序测试 总结 1.串口通信的基础知识 百度百科&#xff1a;串口通信的概…

【dp】背包问题

背包问题 一、背包问题概述二、01背包问题&#xff08;1&#xff09;求这个背包至多能装多大价值的物品&#xff1f;&#xff08;2&#xff09;若背包恰好装满&#xff0c;求至多能装多大价值的物品&#xff1f; 三、完全背包问题&#xff08;1&#xff09;求这个背包至多能装多…

抄写Linux源码(Day19:读取硬盘前的准备工作有哪些?)

回忆我们需要做的事情&#xff1a; 为了支持 shell 程序的执行&#xff0c;我们需要提供&#xff1a; 1.缺页中断(不理解为什么要这个东西&#xff0c;只是闪客说需要&#xff0c;后边再说) 2.硬盘驱动、文件系统 (shell程序一开始是存放在磁盘里的&#xff0c;所以需要这两个东…

1.6.C++项目:仿muduo库实现并发服务器之channel模块的设计

项目完整版在&#xff1a; 文章目录 一、channel模块&#xff1a;事件管理Channel类实现二、提供的功能三、实现思想&#xff08;一&#xff09;功能&#xff08;二&#xff09;意义&#xff08;三&#xff09;功能设计 四、代码&#xff08;一&#xff09;框架&#xff08;二…

【Python从入门到进阶】38、selenium关于Chrome handless的基本使用

接上篇《37、selenium关于phantomjs的基本使用》 上一篇我们介绍了有关phantomjs的相关知识&#xff0c;但由于selenium已经放弃PhantomJS&#xff0c;本篇我们来学习Chrome的无头版浏览器Chrome Handless的使用。 一、Chrome Headless简介 Chrome Headless是一个无界面的浏览…

Kaggle - LLM Science Exam(二):Open Book QAdebertav3-large详解

文章目录 前言&#xff1a;优秀notebook介绍三、Open Book Q&A3.1 概述3.2 安装依赖&#xff0c;导入数据3.3 数据预处理3.3.1 处理prompt3.3.2 处理wiki数据 3.4 使用faiss搜索获取匹配的Prompt-Sentence Pairs3.5 查看context结果并保存3.6 推理3.6.1 加载测试集3.6.2 定…

FFmpeg 基础模块:AVIO、AVDictionary 与 AVOption

目录 AVIO AVDictionary 与 AVOption 小结 思考 我们了解了 AVFormat 中的 API 接口的功能&#xff0c;从实际操作经验看&#xff0c;这些接口是可以满足大多数音视频的 mux 与 demux&#xff0c;或者说 remux 场景的。但是除此之外&#xff0c;在日常使用 API 开发应用的时…