深度学习之卷积神经网络入门

一、引言

在深度学习蓬勃发展的今天,卷积神经网络(Convolutional Neural Network,简称 CNN)凭借其在图像识别、计算机视觉等领域的卓越表现,成为了人工智能领域的核心技术之一。从手写数字识别到复杂的医学影像分析,从自动驾驶中的目标检测到智能安防的人脸识别,CNN 无处不在,深刻改变着我们的生活与工作方式。本文将深入剖析 CNN 的原理、结构组成,并通过实际案例展示其强大的应用能力。

二、原理

1、CNN 的核心思想是利用卷积运算来提取图像的特征。与传统的全连接神经网络不同,CNN 通过卷积层、池化层和激活函数等组件,能够自动学习图像中的局部特征和空间层次结构,从而更有效地处理图像数据。 

2、卷积层是 CNN 的核心组成部分,负责对输入图像进行特征提取。它通过卷积核与输入图像进行卷积运算,将图像与卷积核对应位置的元素相乘并求和,得到卷积结果。例如,一个 3×3 的卷积核在 6×6 的图像上进行步长为 1 的卷积操作,会生成一个 4×4 的特征图。卷积层中的参数主要包括卷积核的数量、大小、步长和填充方式,这些参数的设置会直接影响特征图的尺寸和提取到的特征类型。 

 

 3、激活函数层:为了引入非线性因素,使网络能够学习到复杂的函数关系,在卷积层之后通常会连接激活函数层。常见的激活函数有 ReLU(Rectified Linear Unit)、Sigmoid、Tanh 等。以 ReLU 函数为例,其公式为 f (x) = max (0, x),它能够有效缓解梯度消失问题,加快网络的训练速度,并且计算简单,在现代 CNN 模型中被广泛应用。

4、池化层:池化层的主要作用是对特征图进行下采样,降低数据的维度,减少计算量,同时还能增强模型的鲁棒性。常见的池化操作有最大池化(Max Pooling)和平均池化(Average Pooling)。最大池化会选取池化窗口内的最大值作为输出,能够保留最显著的特征;平均池化则计算池化窗口内的平均值,对特征进行平滑处理。例如,在一个 2×2 的最大池化窗口下,4×4 的特征图会被下采样为 2×2 的特征图。

 

5、全连接层:经过多层卷积和池化操作后,网络提取到的特征被展平并输入到全连接层。全连接层中的每个神经元都与上一层的所有神经元相连,它将提取到的特征进行整合,并通过激活函数进行非线性变换,最终输出分类结果或回归值。在图像分类任务中,全连接层的输出节点数量通常等于类别数,例如在 MNIST 手写数字识别任务中,全连接层的输出节点数为 10,分别对应 0 - 9 这 10 个数字类别。 

 

6、输出层:输出层根据具体的任务类型进行设计。在分类任务中,通常使用 Softmax 函数作为激活函数,将全连接层的输出转换为每个类别的概率分布,概率最大的类别即为预测结果;在回归任务中,输出层直接输出连续的数值

三、案例实现 

1、环境准备与数据加载

在开始之前,我们需要安装 PyTorch 和 torchvision。PyTorch 是一个强大的深度学习框架,而 torchvision 提供了许多与图像相关的数据集和工具。

import torch
from torch import nn   #导入神经网络模块,
from torch.utils.data import DataLoader   #数据包管理工具,打包数据,
from torchvision import datasets    #封装了很多与图像相关的模型,数据集
from torchvision.transforms import ToTensor    #数据转换,张量,将其他类型的数据转换为tensor张量

 2、下载MNIST数据集

'''下载训练数据集(包含训练图片+标签)'''
training_data=datasets.MNIST(root='data',           #表示下载的手写数字 到哪个路径。60000train=True,            #读取下载后的数据 中的 训练集download=True,         #如果你之前已经下载过了,就不用再下载transform=ToTensor(),  #张量,图片是不能直接传入神经网络模型
)    #对于pytorch库能够识别的数据一般是tensor张量。'''下载测试数据集(包含训练图片+标签)'''
test_data=datasets.MNIST(root='data',           #表示下载的手写数字 到哪个路径。60000train=False,           #读取下载后的数据中的训练集download=True,         #如果你之前已经下载过了,就不用再下载transform=ToTensor(),  #Tensor是在深度学习中提出并广泛应用的数据类型
)    #NumPy数组只能在CPU上运行。Tensor可以在GPU上运行,这在深度学习应用中可以显著提高计算速度
print(len(training_data))

3、数据可视化 

'''展示手写字图片,把训练数据集中的前59000张图片展示一下'''
from matplotlib import pyplot as plt
figure = plt.figure()
for i in range(9):img,label=training_data[i+59000]#提取第59000张图片figure.add_subplot(3,3,i+1)#图像窗口中创建多个小窗口,小窗口用于显示图片plt.title(label)plt.axis("off") #  plt.show(I)#是示矢量,plt.imshow(img.squeeze(),cmap="gray") #plt.imshow()将Numpy数组data中的数据显示为图像,并在图形窗口中显贡a = img.squeeze()# img.squeeze()从张量img中去掉维度为1的。如果该维度的大小不为1则张量不会改变。#cmap="gray
plt.show()

4、创建数据加载器和配置设备 

"""创建数据DataLoader(数据加载器)
batch_size:将数据集分成多份,每一份为batch_size个数据
优点:可以减少内存的使用,提高训练速度。"""train_dataloader=DataLoader(training_data,batch_size=64)
test_dataloader=DataLoader(test_data,batch_size=64)for X, y in test_dataloader:#X时打包的的每一个数据包print("Shape of X [N, C, H, W]: {X.shape}")print(f"shape of y: {y.shape} {y.dtype}")break'''断当前设备是否支持GPU,其中mps是苹果m系列芯片的GPU。'''
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")  #字符串的格式化

 5、搭建神经网络模型 

'''定义神经网络 类的继承'''
class CNN(nn.Module):#类的名称def __init__ (self):   #python基础关于类,self类自已本身super(CNN,self).__init__()   #继承的父类初始化self.conv1 = nn.Sequential(    #将多个层组合成一起,创建了一个容器,将多个网络合在一起nn.Conv2d(       #2d一般用于图像,3d用于视频数据(多一个时间维度),1d一般用于结构化的序列数据in_channels=1,   #图像通道个数,1表示灰度图(确定了卷积核 组中的个数)out_channels=16,   # 要得到多少个特征图,卷积核的个数kernel_size=5,     # 卷积核大小,5*5stride=1,          # 步长padding=2,        #一般希望卷积核处理后的结果大小与处理前的数据大小相同,效果会比较好。那padding改如何),                     # 输出的特征图为(16,28,28)nn.ReLU(),            # 激活函数,relu层,不会改变特征图的大小(16,28,28)nn.MaxPool2d(kernel_size=2),        #池化层,进行池化操作(2x2 区域),输出结果为:(16,14,14))self.conv2 = nn.Sequential(   #输入(16, 14, 14)nn.Conv2d(16,32,5,1,2),   # 输出(32,14,14)nn.ReLU(),     # (32*14*14)#nn.Conv2d(32, 32, 5, 1, 2),  # 输出(32,14,14)nn.ReLU(),         #(32 14 14)nn.MaxPool2d(2),     #输出(32,7,7))self.conv3 = nn.Sequential(      #输入(32 7 7)nn.Conv2d(32,64,5,1,2),   #(64,7,7)nn.ReLU(),)self.out=nn.Linear(64*7*6,10)    #全连接层得到的结果def forward(self,x):    #这里必须要写 forward是来自于父类nn里面的函数 要继承父类的功能x=self.conv1(x)x=self.conv2(x)x=self.conv3(x)  #输出(64,64,7,7)x=x.view(x.size(0),-1)#把x的数据变成2维的output=self.out(x)return outputmodel = CNN().to(device)#类的初始化完成,就会创建一个对象,model
print(model)

    定义了一个继承自nn.Module的CNN类,用于构建卷积神经网络模型。模型包含多个卷积层、激活函数层和池化层:

    conv1层:首先通过nn.Conv2d进行卷积操作,将输入的 1 通道图像转换为 16 个特征图;然后使用nn.ReLU激活函数引入非线性;最后通过nn.MaxPool2d进行最大池化操作,降低数据维度。

    conv2层:包含两个卷积层和激活函数层,进一步提取图像特征,并通过池化操作降低维度。

    conv3层:进行卷积和激活操作,继续提取更高级的特征。

    out层:全连接层,将卷积层输出的特征图展平后映射到 10 个类别(对应 0 - 9 这 10 个数字)。

    .forward方法定义了数据在模型中的前向传播过程,确保数据按照正确的顺序通过各个层。

     6、模型训练与测试

    def train(dataloader,model,loss_fn,optimizer):model.train()   #告诉模型,现在要进入训练模式,模型中w进行随机化操作,已经更新w。在训练过程中,w会被修改的
    #pytorch提供2种方式来切换训练和测试的模式,分别是:model.train()和 model.eval()。
    #一般用法是:在训练开始之前写上model.trian(),在测试时写上 model.eval()batch_size_num=1for X,y in dataloader:       #其中batch为每一个数据的编号X,y=X.to(device),y.to(device)    #把训练数据集和标签传入cpu或GPUpred=model.forward(X)    #.forward可以被省略,父类中已经对次功能进行了设置。自动初始化loss=loss_fn(pred,y)     #通过交叉熵损失函数计算损失值loss# Backpropagation 进来一个batch的数据,计算一次梯度,更新一次网络optimizer.zero_grad()    #梯度值清零loss.backward()          #反向传播计算得到每个参数的梯度值woptimizer.step()         #根据梯度更新网络w参数loss_value=loss.item()   #从tensor数据中提取数据出来,tensor获取损失值if batch_size_num %100 ==0:print(f'loss:{loss_value:>7f} [number:{batch_size_num}]')batch_size_num+=1def test(dataloader,model,loss_fn):size=len(dataloader.dataset)num_batches=len(dataloader)  #打包的数量model.eval()  #测试,w就不能再更新。test_loss,correct=0,0with torch.no_grad():    #一个上下文管理器,关闭梯度计算。当你确认不会调用Tensor.backward()的时候。for X,y in dataloader:X,y=X.to(device),y.to(device)pred=model.forward(X)test_loss+=loss_fn(pred,y).item()   #test_loss是会自动累加每一个批次的损失值correct+=(pred.argmax(1)==y).type(torch.float).sum().item()a=(pred.argmax(1)==y)   #dim=1表示每一行中的最大值对应的索引号,dim=0表示每一列中的最大值b=(pred.argmax(1)==y).type(torch.float)test_loss /=num_batchescorrect /=sizeprint(f'Test result: \n Accuracy: {(100*correct)}%, Avg loss: {test_loss}')

    7、定义损失函数和优化器 

    loss_fn=nn.CrossEntropyLoss()   #创建交叉熵损失函数对象,因为手写字识别中一共有10个数字,输出会有10个结果
    optimizer=torch.optim.Adam(model.parameters(),lr=0.01)   #创建一个优化器
    # #params:要训练的参数,一般我们传入的都是model.parameters()
    # lr:learning_rate学习率,也就是步长

    nn.CrossEntropyLoss是交叉熵损失函数,适用于多分类任务,用于计算模型预测结果与真实标签之间的差距。torch.optim.Adam是一种常用的优化器,用于更新模型的参数,以最小化损失函数。lr=0.01设置学习率,控制参数更新的步长。 

    6、模型训练与测试流程

    epoch=9
    for i in range(epoch):print(i+1)train(train_dataloader, model, loss_fn, optimizer)test(test_dataloader,model,loss_fn)

    通过循环进行多个 epoch 的训练,每个 epoch 都会调用train函数对模型进行训练,训练完成后调用test函数对模型在测试集上的性能进行评估。随着训练的进行,可以观察到损失值逐渐降低,准确率逐渐提高,最终得到一个在 MNIST 数据集上表现良好的手写数字识别模型。

    四、总结

    本文详细介绍了利用 PyTorch 构建卷积神经网络实现 MNIST 手写数字识别的全过程。从数据集的准备、模型的构建,到训练和测试的各个环节,都进行了深入的代码解析和原理讲解。通过实践,我们可以看到卷积神经网络在图像识别任务中的强大能力,同时也掌握了 PyTorch 框架的基本使用方法。希望本文能够帮助读者更好地理解和应用卷积神经网络,在深度学习领域不断探索前进。 

    本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/78323.shtml

    如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

    相关文章

    使用RabbitMQ实现判题功能

    这次主要选用RabbitMQ消息队列来对判题服务和题目服务解耦,题目服务只需要向消息队列发送消息,判题服务从消息队列中取信息去执行判题,然后异步更新数据库即可。 五一宝宝请快点跑~~~~~ 先回顾一下RabbitMQ (1)引入依…

    HTML5后台管理界面开发

    HTML5后台管理界面开发 随着互联网技术的快速发展,后台管理系统在各个业务领域中扮演着越来越重要的角色。它不仅帮助企业管理数据、用户和业务流程,也为决策提供了依据。本文将介绍如何使用HTML5开发一个简单的后台管理界面,并结合代码示例…

    Oracle 11g RAC手动打补丁详细步骤

    备份: 节点1: root用户备份GI_home tar cvf Ghome_backup.tar /oracle/grid/crsoracle用户备份ORACLE_HOME tar cvf ohome_backup.tar $ORACLE_HOME节点2: root用户备份GI_home tar cvf Ghome_backup.tar /oracle/grid/crsoracle用户备份…

    xfce桌面汉化设置

    文章目录 汉化配置小结 汉化配置 检查当前语言环境,执行指令locale,如果输出的 LANG、LC_ALL 等未包含 zh_CN.UTF-8,需要设置中文环境。 安装中文语言包 sudo apt update sudo apt install language-pack-zh-hans language-pack-zh-hant设置…

    如何在IDEA中高效使用Test注解进行单元测试?

    在软件开发过程中,单元测试是保证代码质量的重要手段之一。而IntelliJ IDEA作为一款强大的Java开发工具,提供了丰富的功能来支持JUnit测试,尤其是通过Test注解可以快速编写和运行单元测试。那么,如何在IDEA中高效使用Test注解进行…

    Linux 路由

    Linux路由表 一:查看路由二:添加路由三:删除路由四:路由测试五:路由选择机制1.路由表2.路由匹配机制3.策略路由 示例1.多网卡分流2.VPN分流3.双默认路由负载均衡 一:查看路由 # 查看 main 表 ip route sho…

    x-cmd install | brows - 终端里的 GitHub Releases 浏览器,告别繁琐下载!

    目录 核心功能与优势安装适用场景 还在为寻找 GitHub 项目的特定 Release 版本而苦恼吗?还在网页上翻来覆去地查找下载链接吗?现在,有了 brows,一切都将变得简单高效! brows 是一款专为终端设计的 GitHub Releases 浏览…

    Vue多地址代理端口调用

    第一种方法 config.ts文件 配置多条代理服务端口 如下所示:proxy: {/app: {// 其他的端口target: http://125.124.5.117:12877/,changeOrigin: true}/api: {//默认的端口// http://192.168.31.53:5173/target: http://192.168.31.199:18777/,changeOrigin: true,rewrite: pat…

    青少年编程与数学 02-018 C++数据结构与算法 10课题、搜索[查找]

    青少年编程与数学 02-018 C数据结构与算法 10课题、搜索[查找] 一、线性搜索(Linear Search)原理实现步骤代码示例(C)复杂度分析优缺点 二、二分搜索(Binary Search)原理代码示例(C)…

    Linux操作系统从入门到实战(三)Linux基础指令(上)

    Linux操作系统从入门到实战(三)Linux基础指令(上) 前言一、ls 指令二、pwd三、cd四、touch 指令五、mkdir六、rmdir 指令和 rm 指令七、man 指令八、cp九、mv 指令十、cat 指令十一、 more 指令十二、less 指令十四、head 指令十五…

    Java对象转换的多种实现方式

    Java对象转换的多种实现方式 在Java开发中,对象转换是一个常见的需求。特别是在不同层次间传递数据时,通常需要将一个对象转换为另一个对象。虽然JSON序列化/反序列化是一种常见的方法,但在某些场景下可能并不是最佳选择。本文将总结几种常见…

    头歌实训之索引

    🌟 各位看官好,我是maomi_9526! 🌍 种一棵树最好是十年前,其次是现在! 🚀 今天来学习C语言的相关知识。 👍 如果觉得这篇文章有帮助,欢迎您一键三连,分享给更…

    Rundeck 介绍及安装:自动化调度与执行工具

    Rundeck介绍 概述:Rundeck 是什么? Rundeck 是一款开源的自动化调度和任务执行工具,专为运维场景设计,帮助工程师通过统一的平台管理和执行跨系统、跨节点的任务。它由 PagerDuty 维护(2016 年收购)&#…

    基于 Python 的自然语言处理系列(85):PPO 原理与实践

    📌 本文介绍如何在 RLHF(Reinforcement Learning with Human Feedback)中使用 PPO(Proximal Policy Optimization)算法对语言模型进行强化学习微调。 🔗 官方文档:trl PPOTrainer 一、引言&…

    珍爱网:从降本增效到绿色低碳,数字化新基建价值凸显

    2024年12月24日,法大大联合企业绿色发展研究院发布《2024签约减碳与低碳办公白皮书》,深入剖析电子签在推动企业绿色低碳转型中的关键作用,为企业实现环境、社会和治理(ESG)目标提供新思路。近期,法大大将陆…

    Java实现HTML转PDF(deepSeekAi->html->pdf)

    Java实现HTML转PDF,主要为了解决将ai返回的html文本数据转为PDF文件方便用户下载查看。 一、deepSeek-AI提问词 基于以上个人数据。总结个人身体信息,分析个人身体指标信息。再按一个月为维度,详细列举一个月内训练计划,维度详细至每周每天…

    Estimands与Intercurrent Events:临床试验与统计学核心框架

    1. Estimands(估计目标)概述 1.1 定义与作用 1.1.1 定义 Estimand是临床试验中需明确提出的科学问题,即研究者希望通过数据估计的“目标量”,定义“治疗效应”具体含义,确保分析结果与临床问题一致。 例如,在研究某种新药对高血压患者降压效果时,Estimand可定义为“在…

    Jsp技术入门指南【十】IDEA 开发环境下实现 MySQL 数据在 JSP 页面的可视化展示,实现前后端交互

    Jsp技术入门指南【十】IDEA 开发环境下实现 MySQL 数据在 JSP 页面的可视化展示,实现前后端交互 前言一、JDBC 核心接口和类:数据库连接的“工具箱”1. 常用的 2 个“关键类”2. 必须掌握的 5 个“核心接口” 二、创建 JDBC 程序的步骤1. 第一步&#xf…

    深入理解HotSpot JVM 基本原理

    关于JAVA Java编程语言是一种通用的、并发的、面向对象的语言。它的语法类似于C和C++,但它省略了许多使C和C++复杂、混乱和不安全的特性。 Java 是几乎所有类型的网络应用程序的基础,也是开发和提供嵌入式和移动应用程序、游戏、基于 Web 的内容和企业软件的全球标准。. 从…

    【HTTP/3:互联网通信的量子飞跃】

    HTTP/3:互联网通信的量子飞跃 如果说HTTP/1.1是乡村公路,HTTP/2是现代高速公路系统,那么HTTP/3就像是一种革命性的"传送门"技术,它彻底重写了数据传输的底层规则,让信息几乎可以瞬间抵达目的地,…