12.26 学习卷积神经网路(CNN)

完全是基于下面这个博客来进行学习的,感谢!

​​【深度学习基础】详解Pytorch搭建CNN卷积神经网络LeNet-5实现手写数字识别_pytorch cnn-CSDN博客

 基于深度神经网络DNN实现的手写数字识别,将灰度图像转换后的二维数组展平到一维,将一维的784个特征作为模型输入。在“展平”的过程中必然会失去一些图像的形状结构特征,因此基于DNN的实现方式并不能很好的利用图像的二维结构特征,而卷积神经网络CNN对于处理图像的位置信息具有一定的优势。因此卷积神经网络经常被用于图像识别/处理领域。

1、卷积层

内参数(卷积核本身)  

CNN中的卷积层和DNN中的全连接层是平级关系,在DNN中,我们训练的内参数是全连接层的权重w和偏置b,CNN也类似,CNN训练的是卷积核,也就相当于包含了权重和偏置两个内部参数。

 输入一个多维数据(上图为二维),与卷积核进行运算,即输入中与卷积核形状相同的部分,分别与卷积核进行逐个元素相乘再相加。例如计算结果中坐上角的15是根据如下过程计算得到的:

 

逐个元素相乘再相加,即:

1 * 2 + 2 * 0 + 3* 1 + 0 * 0 + 1 * 1 + 2 * 2 + 3 * 1 + 0 * 0 + 1 * 2 = 15 

外参数(填充和步幅)

   填充(padding)

   显然,只要卷积核的大小>1*1,必然会导致图像越卷越小,为了防止输入经过多个卷积层后变得过小,可以在发生卷积层之前,先向输入图像的外围填入固定的数据(比如0),这个步骤称之为填充,如下图:

 在使用Pytorch搭建卷积层的时候,需要在对应的接口中添加这个padding参数,向上图中这种情况,相当于在3*3的卷积核外围添加了“一圈”,则padding = 1,卷积层的接口中就要这样写:

nn.Conv2d(in_channels=1, out_channels=6, kernel_size=3, paddding=1)

      参数in_channels和out_channels是对应于这个卷积层输入和输出的通道数参数,这里我们先放一放。

步幅(stride)

   步幅指的是使用卷积核的位置间隔,即输入中参与运算的那个范围每次移动的距离。前面几个示意图中的步幅均为1,即每次移动一格,如果设置stride=2,kernel_size=2,则效果如下:

     此时需要在卷积层接口中添加参数stride=2。

输入与尺寸的关系

 综上所述,结合外参数(步幅、填充)和内参数(卷积核),可以看出如下规律:

卷积核越大,输出越小。

步幅越大,输出越小。

填充越大,输出越大。

     用公式表示定量关系:

      如果输入和卷积核均为方阵,设输入尺寸为W*W,输出尺寸为N*N,卷积核尺寸为F*F,填充的圈数为P,步幅为S,则有关系:

N=\frac{W+2P-F}{S}+1

    这个关系大家要重点掌握,也可以自己推导一下,并不复杂。如果输入和卷积核不为方阵,设输入尺寸是H*W,输出尺寸是OH*OW,卷积核尺寸为FH*FW,填充为P,步幅为S,则输出尺寸OH*OW的计算公式是:

OH=\frac{H+2P-FH}{S}+1

OW=\frac{W+2P-FW}{S}+1

2、多通道问题

多通道输入

对于手写数字识别这种灰度图像,可以视为仅有(高*长)二维的输入。然而,对于彩色图像,每一个像素点都相当于是RGB的三个值的组合,因此对于彩色的图像输入,除了高*长两个维度外,还有第三个维度——通道,即红、绿、蓝三个通道,也可以视为3个单通道的二维图像的混合叠加。

当输入数据仅为二维时,卷积层的权重往往被称作卷积核(Kernel);

当输入数据为三维或更高时,卷积层的权重往往被称作滤波器(Filter)。

     对于多通道输入,输入数据和滤波器的通道数必须保持一致。这样会导致输出结果降维成二维,如下图:

对形状进行一下抽象,则输入数据C*H*W和滤波器C*FH*FW都是长方体,结果是一个长方形1*OH*OW,注意C,H,W是固定的顺序,通道数要写在最前。

多通道输出

 如果要实现多通道输出,那么就需要多个滤波器,让三维输入与多个滤波器进行卷积,就可以实现多通道输出,输出的通道数FN就是滤波器的个数FN,如下图:

 和单通道一样,卷积运算后也有偏置,如果进一步追加偏置,则结果如下:每个通道都有一个单独的偏置

3、池化层

池化,也叫汇聚(Pooling)。池化层通常位于卷积层之后(有时也可以不设置池化层),其作用仅仅是在一定范围内提取特征值,所以并不存在要学习的内部参数。池化仅仅对图像的高H和宽W进行特征提取,并不改变通道数C

一般有平均汇聚和最大值汇聚两种。

平均汇聚

如上图,池化的窗口大小为2*2,对应的步幅为2,因此对于上图这种情况,对应的Pytorch接口如下:

nn.AvgPool2d(kernel_size=2, stride=2) 

最大值汇聚

对应的Pytorch接口如下:

nn.MaxPool2d(kernel_size=2, stride=2) 

4、手写数字识别

详细内容见,我自己就在jupyter上写写代码【深度学习基础】详解Pytorch搭建CNN卷积神经网络LeNet-5实现手写数字识别_pytorch cnn-CSDN博客

任务描述

输入就相当于一个单通道的图像,是二维的。我们在实现的时候,要将每个样本图像转换为28*28的张量,作为输入

 因此对于整个手写数字识别的任务,模型的输入是一副图像,输出则是一个对应的识别出的数字(0-9之间的整数)。这里注意,在进行模型训练时,PyTorch会在整个过程中自动将输出转换为One-Hot编码,因此我们在训练时不需要手动将输出转换为One-Hot编码,但进行模型评估测试时,由于需要比对预测输出(One-hot)和真实输出(0-9的数字),要进行一次转化。

    对于单个样本,每个图像都是一个二维灰度图像,像素为28*28.二维灰度图像的通道数为1,因此可以将每个样本图像转换为28*28的张量,作为输入。相当于是下图这样的形式:

    具体怎么转换,需要用到torchvision库中的trabsforms进行图像转换,将数据集转换为张量的形式,并调整数据集的统计分布(转换为标准正态分布更利于训练)。

网络结构(LeNet-5)

LeNet-5起源于1998年,在手写数字识别上非常成功。其结构如下:

再列一个表格,具体结构如下:

注:输出层的激活函数目前已经被Softmax取代。

至于这些尺寸关系,我举个两例子吧:

   以第一层C1的输入和输出为例。输入尺寸W是28*28,卷积核F尺寸为5*5,步幅S为1,填充P为2,那么输出N的28*28怎么来的呢?按照公式如下:

N=\frac{W+2P-F}{S}+1=\frac{28+2*2-5}{1}+1=28

   我们也可以观察到第一层的卷积核个数为6,则输出的通道数也为6。

  再看一下第一个池化层S2,输入尺寸是28*28,卷积核F大小为2*2(此处的“卷积核”实际上指的是采样范围),步幅S=2,填充P为0,则输出的14*14是这么算出来的:

N=\frac{W+2P-F}{S}+1=\frac{28+2*0-2}{2}+1=14

import torch
from torch import nn
from net import MyLeNet5
from torch.optim import lr_scheduler
from torchvision import datasets, transforms
import os# 将图像转换为张量形式
data_transform = transforms.Compose([transforms.ToTensor()
])# 加载训练数据集
train_dataset = datasets.MNIST(root='D:\\Jupyter\\dataset\\minst',  # 下载路径train=True,   # 是训练集download=True,   # 如果该路径没有该数据集,则进行下载transform=data_transform   # 数据集转换参数
)# 批次加载器
train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=16, shuffle=True)# 加载测试数据集
test_dataset = datasets.MNIST(root='D:\\Jupyter\\dataset\\minst',  # 下载路径train=False,   # 是训练集download=True,   # 如果该路径没有该数据集,则进行下载transform=data_transform   # 数据集转换参数
)
test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=16, shuffle=True)# 判断是否有gpu
device = "cuda" if torch.cuda.is_available() else "cpu"# 调用net,将模型数据转移到gpu
model = MyLeNet5().to(device)# 选择损失函数
loss_fn = nn.CrossEntropyLoss()    # 交叉熵损失函数,自带Softmax激活函数# 定义优化器
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)# 学习率每隔10轮次, 变为原来的0.1
lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)# 定于训练函数
def train(dataloader, model, loss_fn, optimizer):loss, current, n = 0.0, 0.0, 0for batch, (X, y) in enumerate(dataloader):# 前向传播X, y = X.to(device), y.to(device)output = model(X)cur_loss = loss_fn(output, y)_, pred = torch.max(output, dim=1)# 计算当前轮次时,训练集的精确度cur_acc = torch.sum(y == pred)/output.shape[0]# 反向传播optimizer.zero_grad()cur_loss.backward()optimizer.step()loss += cur_loss.item()current += cur_acc.item()n = n + 1print("train_loss: ", str(loss/n))print("train_acc: ", str(current/n))def test(dataloader, model, loss_fn):model.eval()loss, current, n = 0.0, 0.0, 0# 该局部关闭梯度计算功能,提高运算效率with torch.no_grad():for batch, (X, y) in enumerate(dataloader):# 前向传播X, y = X.to(device), y.to(device)output = model(X)cur_loss = loss_fn(output, y)_, pred = torch.max(output, dim=1)# 计算当前轮次时,训练集的精确度cur_acc = torch.sum(y == pred) / output.shape[0]loss += cur_loss.item()current += cur_acc.item()n = n + 1print("test_loss: ", str(loss / n))print("test_acc: ", str(current / n))return current/n    # 返回精确度# 开始训练
epoch = 50
max_acc = 0
for t in range(epoch):print(f"epoch{t+1}\n---------------")train(train_dataloader, model, loss_fn, optimizer)a = test(test_dataloader, model, loss_fn)# 保存最好的模型参数if a > max_acc:folder = 'save_model'if not os.path.exists(folder):os.mkdir(folder)max_acc = aprint("current best model acc = ", a)torch.save(model.state_dict(), 'save_model/best_model.pth')
print("Done!")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/890761.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Unity URP多光源支持,多光源阴影投射,多光源阴影接收(优化版)

目录 前言: 一、属性 二、SubShader 三、ForwardLitPass 定义Tags 声明变体 声明变量 定义结构体 顶点Shader 片元Shader 四、全代码 四、添加官方的LitShader代码 五、全代码 六、效果图 七、结语 前言: 哈喽啊,我又来啦。这…

如何使用React,透传各类组件能力/属性?

在23年的时候,我主要使用的框架还是Vue,当时写了一篇“如何二次封装一个Vue3组件库?”的文章,里面涉及了一些如何使用Vue透传组件能力的方法。在我24年接触React之后,我发现这种扩展组件能力的方式有一个专门的术语&am…

109.【C语言】数据结构之求二叉树的高度

目录 1.知识回顾:高度(也称深度) 2.分析 设计代码框架 返回左右子树高度较大的那个的写法一:if语句 返回左右子树高度较大的那个的写法二:三目操作符 3.代码 4.反思 问题 出问题的代码 改进后的代码 执行结果 1.知识回顾&#xf…

分析排名靠前的一些自媒体平台,如何运用这些平台?

众所周知,现在做网站越来越难了,主要的原因还是因为流量红利时代过去了。并且搜索引擎都在给自己的平台做闭环改造。搜索引擎的流量扶持太低了。如百度投资知乎,给知乎带来很多流量扶持,也为自身内容不足做一个填补。 而我们站长…

2024大模型在软件开发中的具体应用有哪些?(附实践资料合集)

大模型在软件开发中的具体应用非常广泛,以下是一些主要的应用领域: 自动化代码生成与智能编程助手: AI大模型能够根据开发者的自然语言描述自动生成代码,减少手动编写代码的工作量。例如,GitHub Copilot工具就是利用AI…

Ubuntu网络配置(桥接模式, nat模式, host主机模式)

windows上安装了vmware虚拟机, vmware虚拟机上运行着ubuntu系统。windows与虚拟机可以通过三种方式进行通信。分别是桥接模式;nat模式;host模式 一、桥接模式 所谓桥接模式,也就是虚拟机与宿主机处于同一个网段, 宿主机…

3.系统学习-熵与决策树

熵与决策树 前言1.从数学开始信息量(Information Content / Shannon information)信息熵(Information Entropy)条件熵信息增益 决策树认识2.基于信息增益的ID3决策树3.C4.5决策树算法C4.5决策树算法的介绍决策树C4.5算法的不足与思考 4. CART 树基尼指数(基尼不纯度…

SpringBoot + HttpSession 自定义生成sessionId

SpringBoot HttpSession 自定义生成sessionId 业务场景实现方案 业务场景 最近在做用户登录过程中,由于默认ID是通过UUID创建的,缺乏足够的安全性,决定要自定义生成 sessionId。 实现方案 正常的获取session方法如下: HttpSe…

破解海外业务困局:新加坡服务器托管与跨境组网策略

在当今全球化商业蓬勃发展的浪潮之下,众多企业将目光投向海外市场,力求拓展业务版图、抢占发展先机。而新加坡,凭借其卓越的地理位置、强劲的经济发展态势以及高度国际化的营商环境,已然成为企业海外布局的热门之选。此时&#xf…

数学课程评价系统:客户服务与教学支持

2.1 SSM框架介绍 本课题程序开发使用到的框架技术,英文名称缩写是SSM,在JavaWeb开发中使用的流行框架有SSH、SSM、SpringMVC等,作为一个课题程序采用SSH框架也可以,SSM框架也可以,SpringMVC也可以。SSH框架是属于重量级…

攻防世界web第三题file_include

<?php highlight_file(__FILE__);include("./check.php");if(isset($_GET[filename])){$filename $_GET[filename];include($filename);} ?>惯例&#xff1a; 代码审查&#xff1a; 1.可以看到include(“./check.php”);猜测是同级目录下有一个check.php文…

【深度学习环境】NVIDIA Driver、Cuda和Pytorch(centos9机器,要用到显示器)

文章目录 一 、Anaconda install二、 NIVIDIA driver install三、 Cuda install四、Pytorch install 一 、Anaconda install Step 1 Go to the official website: https://www.anaconda.com/download Input your email and submit. Step 2 Select your version, and click i…

寻找适合小户型的开源知识库open source knowledge base之路

寻找一个开源的知识库&#xff0c;为了把以前花很多时间收集的信息或是项目/课程资料放到一个容易归类和管理的私有自主系统中&#xff0c;以便更容易查阅&#xff0c;花更少时间收集、对比版本及分享等一系列管理工作&#xff0c;同时确保在需要时可以相对快速找到有用的资料&…

C语言勘破之路-最终篇 —— 预处理(上)

人无完人&#xff0c;持之以恒&#xff0c;方能见真我&#xff01;&#xff01;&#xff01; 共同进步&#xff01;&#xff01; 文章目录 一、预定义符号二、#define定义常量三.、#define定义宏四、带有副作用的宏参数五、宏替换的规则六、宏和函数的对比1.宏的优势2.函数的优…

学习threejs,THREE.RingGeometry 二维平面圆环几何体

&#x1f468;‍⚕️ 主页&#xff1a; gis分享者 &#x1f468;‍⚕️ 感谢各位大佬 点赞&#x1f44d; 收藏⭐ 留言&#x1f4dd; 加关注✅! &#x1f468;‍⚕️ 收录于专栏&#xff1a;threejs gis工程师 文章目录 一、&#x1f340;前言1.1 ☘️THREE.RingGeometry 圆环几…

Win11系统下Oracle11g数据库下载与安装使用教程

文章目录 一、Oracle下载与安装1.1 解压安装包1.2 开始安装Oracle11g1.2.1 用户 1.3 测试数据库是否配置成功1.4 了解一下 Oracle相关服务1.5 了解Oracle体系结构 二、使用工具连接数据库2.1 PL/ SQL 连接本地oracle 三、PL/ SQL远程访问数据库3.1 可能踩坑问题&#xff08;TNS…

数据结构(Java版)第六期:LinkedList与链表(一)

目录 一、链表 1.1. 链表的概念及结构 1.2. 链表的实现 专栏&#xff1a;数据结构(Java版) 个人主页&#xff1a;手握风云 一、链表 1.1. 链表的概念及结构 链表是⼀种物理存储结构上⾮连续存储结构&#xff0c;数据元素的逻辑顺序是通过链表中的引⽤链接次序实现的。与火车…

从零开始C++棋牌游戏开发之第三篇:游戏的界面布局设计

在游戏开发的旅途中&#xff0c;界面布局设计是一个充满创意和挑战的环节。对于棋牌类游戏而言&#xff0c;界面不仅仅是功能的载体&#xff0c;更是玩家与游戏互动的桥梁。一个清晰、直观且美观的界面可以显著提升游戏的用户体验。 在这篇文章中&#xff0c;我们将从功能需求…

计算机基础知识——数据结构与算法(五)(山东省大数据职称考试)

大数据分析应用-初级 第一部分 基础知识 一、大数据法律法规、政策文件、相关标准 二、计算机基础知识 三、信息化基础知识 四、密码学 五、大数据安全 六、数据库系统 七、数据仓库. 第二部分 专业知识 一、大数据技术与应用 二、大数据分析模型 三、数据科学 数据结构与算法…

数据库管理-第275期 Oracle 23ai:画了两张架构图(20241225)

数据库管理275期 2024-12-25 数据库管理-第275期 Oracle 23ai&#xff1a;画了两张架构图&#xff08;20241225&#xff09;1 系统管理分片2 用户定义分片总结 数据库管理-第275期 Oracle 23ai&#xff1a;画了两张架构图&#xff08;20241225&#xff09; 作者&#xff1a;胖…