看完这篇我就不信还有人不懂卷积神经网络!

看完这篇我就不信还有人不懂卷积神经网络!

前言

在深度学习大🔥的当下,我知道介绍卷积神经网络的文章已经在全网泛滥,但我还是想要写出一点和别人不一样的东西,尽管要讲的知识翻来覆去还是那么一些,但我想尽可能做到极其通俗易懂,只要稍微有点计算机和线性代数基础的同学都能看懂。

水平有限,但凡有错误或理解不到位的地方,欢迎各位大佬指出🙏

什么是神经网络?

在介绍卷积神经网络之前,我们先回顾一下神经网络的基本知识📖。就目前而言,神经网络是深度学习算法的核心,我们所熟知的很多深度学习算法的背后其实都是神经网络。神经网络由节点层组成,通常包含一个输入层、一个或多个隐藏层和一个输出层,我们所说的几层神经网络通常指的是隐藏层的个数,因为输入层和输出层通常是固定的。节点之间相互连接并具有相关的权重和阈值。如果节点的输出高于指定的阈值,则激活该节点并将数据发送到网络的下一层。否则,没有数据被传递到网络的下一层。关于节点激活的过程有没有觉得非常相似?没错,这其实就是生物的神经元产生神经冲动的过程的模拟。

image.png

神经网络类型多样,适用于不同的应用场景。例如,循环神经网络(RNN)通常用于自然语言处理和语音识别,而卷积神经网络(ConvNetsCNN)则常用于分类和计算机视觉任务。在 CNN 产生之前,识别图像中的对象需要手动的特征提取方法。现在,卷积神经网络为图像分类和对象识别任务提供了一种更具可扩展性的方法,它利用线性代数的原理,特别是矩阵乘法,来识别图像中的模式。但是,CNN的计算要求很高,通常需要图形处理单元 (GPU) 来训练模型,否则训练速度很慢。

什么是卷积神经网络?

卷积神经网络是一种基于卷积计算的前馈神经网络。与普通的神经网络相比,它具有局部连接、权值共享等优点,使其学习的参数量大幅降低,且网络的收敛速度也更快。同时,卷积运算能更好地提取图像的特征。卷积神经网络的基本组件有卷积层、池化层和全连接层。卷积层是卷积网络的第一层,其后可以跟着其他卷积层或池化层,最后一层是全连接层。越往后的层识别图像越大的部分,较早的层专注于简单的特征,例如颜色和边缘。随着图像数据在 CNN 的各层中前进,它开始识别物体的较大元素或形状,直到最终识别出预期的物体。

image.png

下面将简单介绍这几种基本组件的相关原理和作用。

卷积层

卷积层是 CNN 的核心组件,其作用是提取样本的特征。它由三个部分组成,即输入数据、过滤器和特征图。在计算机内部,图像以像素矩阵的形式存储。若输入数据是一个RGB图像,则由 3D 像素矩阵组成,这意味着输入将具有三个维度——高度、宽度和深度。过滤器,也叫卷积核、特征检测器,其本质是一个二维(2-D)权重矩阵,它将在图像的感受野中移动,检查特征是否存在。

卷积核的大小不一,但通常是一个 3x3 的矩阵,这也决定了感受野的大小。不同卷积核提取的图像特征也不同。从输入图像的像素矩阵的左上角开始,卷积核的权重矩阵与像素矩阵的对应区域进行点积运算,然后移动卷积核,重复该过程,直到卷积核扫过整个图像。这个过程就叫做卷积。卷积运算的最终输出就称为特征图、激活图或卷积特征

image.png

如上图所示☝,特征图中的每个输出值不必连接到输入图像中的每个像素值,它只需要连接到应用过滤器的感受野。由于输出数组不需要直接映射到每个输入值,卷积层(以及池化层)通常被称为“部分连接”层。这种特性也被描述为局部连接

当卷积核在图像上移动时,其权重是保持不变的,这被称为权值共享。一些参数,如权重值,是在训练过程中通过反向传播和梯度下降的过程进行调整的。在开始训练神经网络之前,需要设置三个影响输出体积大小的超参数:

  1. 过滤器的数量 :影响输出的深度。例如,三个不同的过滤器将产生三个不同的特征图,从而产生三个深度。
  2. 步长(Stride) :卷积核在输入矩阵上移动的距离或像素数。虽然大于等于 2 的步长很少见,但较大的步长会产生较小的输出。
  3. 零填充(Zero-padding) :通常在当过滤器不适合输入图像时使用。这会将输入矩阵之外的所有元素设置为零,从而产生更大或相同大小的输出。有三种类型的填充:
    • Valid padding:这也称为无填充。在这种情况下,如果维度不对齐,则丢弃最后一个卷积。
    • Same padding:此填充确保输出层与输入层具有相同的大小。
    • Full padding:这种类型的填充通过在输入的边界添加零来增加输出的大小。

在每次卷积操作之后,CNN 的特征图经过激活函数(SigmoidReLULeaky ReLU等)的变换,从而对输出做非线性的映射,以提升网络的表达能力。

image.png

CNN 有多个卷积层时,后面的层可以看到前面层的感受野内的像素。例如,假设我们正在尝试确定图像是否包含自行车。我们可以把自行车视为零件的总和,即由车架、车把、车轮、踏板等组成。自行车的每个单独部分在神经网络中构成了一个较低级别的模式,各部分的组合则代表了一个更高级别的模式,这就在 CNN 中创建了一个特征层次结构。

image.png

池化层

为了减少特征图的参数量,提高计算速度,增大感受野,我们通常在卷积层之后加入池化层(也称降采样层)。池化能提高模型的容错能力,因为它能在不损失重要信息的前提下进行特征降维。这种降维的过程一方面使得模型更加关注全局特征而非局部特征,另一方面具有一定的防止过拟合的作用。池化的具体实现是将感受域中的值经过聚合函数后得到的结果作为输出。池化有两种主要的类型:

  • 最大池化(Max pooling):当过滤器在输入中移动时,它会选择具有最大值的像素发送到输出数组。与平均池化相比,这种方法的使用频率更高。
  • 平均池化(Average pooling):当过滤器在输入中移动时,它会计算感受域内的平均值以发送到输出数组。

img

全连接层

全连接层通常位于网络的末端,其结构正如其名。如前所述,输入图像的像素值在部分连接层中并不直接连接到输出层。但是,在全连接层中,输出层中的每个节点都直接连接到前一层中的节点,特征图在这里被展开成一维向量。

该层根据通过前几层提取的特征及其不同的过滤器执行分类任务。虽然卷积层和池化层倾向于使用 ReLu 函数,但 FC 层通常用 Softmax 激活函数对输入进行适当分类,产生 [0, 1] 之间的概率值。

自定义卷积神经网络进行手写数字识别

  • 导包
python复制代码import time
import numpy as np
import torch
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoaderif torch.cuda.is_available():torch.backends.cudnn.deterministic = True
  • 设置参数 & 加载数据集
ini复制代码##########################
### SETTINGS
########################### Device
device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")# Hyperparameters
random_seed = 1
learning_rate = 0.05
num_epochs = 10
batch_size = 128# Architecture
num_classes = 10##########################
### MNIST DATASET
########################### Note transforms.ToTensor() scales input images
# to 0-1 range
train_dataset = datasets.MNIST(root='data', train=True, transform=transforms.ToTensor(),download=True)test_dataset = datasets.MNIST(root='data', train=False, transform=transforms.ToTensor())train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)# Checking the dataset
for images, labels in train_loader:  print('Image batch dimensions:', images.shape)print('Image label dimensions:', labels.shape)break
  • 模型定义
ini复制代码##########################
### MODEL
##########################class ConvNet(torch.nn.Module):def __init__(self, num_classes):super(ConvNet, self).__init__()# calculate same padding:# (w - k + 2*p)/s + 1 = o# => p = (s(o-1) - w + k)/2# 28x28x1 => 28x28x8self.conv_1 = torch.nn.Conv2d(in_channels=1,out_channels=8,kernel_size=(3, 3),stride=(1, 1),padding=1) # (1(28-1) - 28 + 3) / 2 = 1# 28x28x8 => 14x14x8self.pool_1 = torch.nn.MaxPool2d(kernel_size=(2, 2),stride=(2, 2),padding=0) # (2(14-1) - 28 + 2) = 0# 14x14x8 => 14x14x16self.conv_2 = torch.nn.Conv2d(in_channels=8,out_channels=16,kernel_size=(3, 3),stride=(1, 1),padding=1) # (1(14-1) - 14 + 3) / 2 = 1                 # 14x14x16 => 7x7x16                             self.pool_2 = torch.nn.MaxPool2d(kernel_size=(2, 2),stride=(2, 2),padding=0) # (2(7-1) - 14 + 2) = 0self.linear_1 = torch.nn.Linear(7*7*16, num_classes)for m in self.modules():if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.Linear):m.weight.data.normal_(0.0, 0.01)m.bias.data.zero_()if m.bias is not None:m.bias.detach().zero_()def forward(self, x):out = self.conv_1(x)out = F.relu(out)out = self.pool_1(out)out = self.conv_2(out)out = F.relu(out)out = self.pool_2(out)logits = self.linear_1(out.view(-1, 7*7*16))probas = F.softmax(logits, dim=1)return logits, probastorch.manual_seed(random_seed)
model = ConvNet(num_classes=num_classes)model = model.to(device)optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)  
  • 模型训练
scss复制代码def compute_accuracy(model, data_loader):correct_pred, num_examples = 0, 0for features, targets in data_loader:features = features.to(device)targets = targets.to(device)logits, probas = model(features)_, predicted_labels = torch.max(probas, 1)num_examples += targets.size(0)correct_pred += (predicted_labels == targets).sum()return correct_pred.float()/num_examples * 100start_time = time.time()    
for epoch in range(num_epochs):model = model.train()for batch_idx, (features, targets) in enumerate(train_loader):features = features.to(device)targets = targets.to(device)### FORWARD AND BACK PROPlogits, probas = model(features)cost = F.cross_entropy(logits, targets)optimizer.zero_grad()cost.backward()### UPDATE MODEL PARAMETERSoptimizer.step()### LOGGINGif not batch_idx % 50:print ('Epoch: %03d/%03d | Batch %03d/%03d | Cost: %.4f' %(epoch+1, num_epochs, batch_idx, len(train_loader), cost))model = model.eval()print('Epoch: %03d/%03d training accuracy: %.2f%%' % (epoch+1, num_epochs, compute_accuracy(model, train_loader)))print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))
  • 模型评估
python复制代码with torch.set_grad_enabled(False): # save memory during inferenceprint('Test accuracy: %.2f%%' % (compute_accuracy(model, test_loader)))

输出如下👇

Test accuracy: 97.97%

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/636568.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Redis原理篇(SkipList)

一.概述 本质是双端链表,只不过在正向遍历时可以不一个一个遍历,而是可以跳着遍历。 怎么实现的呢,下面是SkipList源码 二.源码 1. zskiplist 意义:跳表 zskiplist里面有头指针和尾指针,节点数量,最大…

【信号与系统】(1)连续和离散表示

在信号处理和数学中,连续和离散是两种基本的表示方法,用于描述信号、函数或数据集。 对连续信号 f(t)进行等间隔采样得到 连续表示(Continuous Representation) 连续表示通常用于描述在一个连续范围内变化的信号或函数。在连续…

Java学习(二十一)--JDBC/数据库连接池

为什么需要 传统JDBC数据库连接,使用DriverManager来获取; 每次向数据库建立连接时都要将Connection加载到内存中,再验证IP地址、用户名和密码(0.05s~1s)时间。 需要数据库连接时候,就向数据库要求一个&#xf…

JS-WebAPIS(四)

日期对象(常用) • 实例化 在代码中发现了 new 关键字时,一般将这个操作称为实例化创建一个时间对象并获取时间 获得当前时间 获得指定时间 • 时间对象方法 使用场景:因为日期对象返回的数据我们不能直接使用,所以…

【2023我的编程之旅】七次不同的计算机二级考试经历分享

目录 我报考过的科目 第一次报考MS Office 第二次报考Web语言,C语言,C语言 第三次报考C语言,C语言,Java语言 分享一些备考二级的方法 一些需要注意的细节 结语 2023年的CSDN征文活动已经进入了尾声,在这最后我…

Excel·VBA合并工作簿2

其他合并工作簿的方法,见之前的文章《ExcelVBA合并工作簿》 目录 8,合并文件夹下所有工作簿中所有工作表,按表头汇总举例 8,合并文件夹下所有工作簿中所有工作表,按表头汇总 与之前的文章《ExcelVBA合并工作簿&#x…

006.Oracle事务处理

我 的 个 人 主 页:👉👉 失心疯的个人主页 👈👈 入 门 教 程 推 荐 :👉👉 Python零基础入门教程合集 👈👈 虚 拟 环 境 搭 建 :👉&…

vue2 点击按钮下载文件保存到本地(后台返回的zip压缩流)

// import ./mock/index.js; // 该项目所有请求使用mockjs模拟 去掉mock页面url下载 console.log(res, res)//token 是使页面不用去登录了if (res.file) {window.location.href Vue.prototype.$config.VUE_APP_BASE_IDSWAPI Vue.prototype.$config.VUE_APP_IDSW /service/mode…

【Linux上创建一个LVM卷组,将多个物理卷添加到卷组中使用】

Linux上创建一个LVM卷组,将多个物理卷添加到卷组中使用 目录1.列出当前系统中所有的块设备信息,包括磁盘、分区、逻辑卷等2.对磁盘进行分区操作3.创建了一个名为 vg_data 的卷组4.将物理卷添加到已经存在的卷组5.在卷组中创建一个逻辑卷6.查看已创建的 L…

CodeWave智能开发平台--03--目标:应用创建--10初级采购管理系统总结

摘要 本文是网易数帆CodeWave智能开发平台系列的第14篇,主要介绍了基于CodeWave平台文档的新手入门进行学习,实现一个完整的应用,本文主要完成10初级采购管理系统总结 CodeWave智能开发平台的14次接触 CodeWave参考资源 网易数帆CodeWave…

【C++】string的基本使用

从这篇博客开始,我们的C部分就进入到了STL,STL的出现可以说是C发展历史上非常关键的一步,自此C和C语言有了较为明显的差别。那么什么是STL呢? 后来不断的演化,发展成了知名的两个版本,一个叫做P.J.版本&am…

【鸿蒙4.0】详解harmonyos开发语言ArkTS

文章目录 一.什么是ArkTS?1.ArkTS的背景2.了解js,ts,ArkTS的演变js(Javascript)Javascript的简介Javascript的特点 ts(Typescript)ArkTS 二. ArkTS的特点 一.什么是ArkTS? 1.ArkTS的背景 如官方文档所描述,ArkTS是基…

Ubuntu使用docker-compose安装mysql8或mysql5.7

ubuntu环境搭建专栏🔗点击跳转 Ubuntu系统环境搭建(十四)——使用docker-compose安装mysql8或mysql5.7 文章目录 Ubuntu系统环境搭建(十四)——使用docker-compose安装mysql8或mysql5.7MySQL81.新建文件夹2.创建docke…

【问题记录】Linux下克隆git项目到本地

1.出现远端克隆git上代码失败 (1)公钥有问题 linux下git生成公钥失败 解决方法: 删除.ssh下全部的文件,并重新设置用户名和邮箱再重新生成ssh公钥 (2)在询问是不是要把远端地址加入到konw_host中&#x…

【高等数学之牛莱公式】

一、深入挖掘定积分 二、变限积分 三、变限积分的"天然"连续性 四、微积分基本定理 五、定积分基本方法 5.1、换元法 5.2、分部积分法 六、定积分经典结论 七、区间再现公式 八、三角函数积分变换公式 九、周期函数积分变换公式 十、分段函数求定积分

python requests模块

目录 一:介绍 二:发送get请求 三:发送post请求 四:发送put请求 五:发送delele请求 六:响应信息 一:介绍 requests 是 Python 中的一个非常流行的 HTTP 客户端库,用于发送 HTTP…

Python 爬虫 之 抖音视频采集

嗨喽,大家好呀~这里是爱看美女的茜茜呐 知识点: 动态数据抓包 requests发送请求 开发环境: python 3.8 运行代码 pycharm 2022.3 辅助敲代码 requests pip install requests 如何安装python第三方模块: win R 输入 cmd 点击确定, 输入安装命令 pip install …

论文阅读笔记AI篇 —— Transformer模型理论+实战 (四)

论文阅读笔记AI篇 —— Transformer模型理论实战 (四) 一、理论1.1 理论研读1.2 什么是AI Agent? 二、实战2.1 先导知识2.1.1 tensor的创建与使用2.1.2 PyTorch的模块2.1.2.1 torch.nn.Module类的继承与使用2.1.2.2 torch.nn.Linear类 2.2 Transformer代…

k8s源码阅读:Informer源码解析

写在之前 Kubernetes的Informer机制是一种用于监控资源对象变化的机制。它提供了一种简化开发者编写控制器的方式,允许控制器能够及时感知并响应 Kubernetes 集群中资源对象的变化。Informer通过与Kubernetes API服务器进行交互,通过监听API服务器上资源…

Idea 开发环境不断切换git代码分支导致冲掉别人代码

问题分析 使用git reflog查看执行命令,以下是发生事故的切换和提交动作 46f72622e1 HEAD{41}: commit: feat: 【Sales - 6.3】小程序端不登录也可以录入客户线索 c5e7d9f6e1 HEAD{42}: fetch origin feature/20240102_Sales6.3_xingang:feature/20240102_Sales6.3…