23 经典卷积神经网络 LeNet【李沐动手学深度学习v2课程笔记】 (备注:提到如何把代码从CPU改到在GPU上使用)

目录

1. LeNet

2. 实现代码

3. 模型训练

4. 小结


本节将介绍LeNet,它是最早发布的卷积神经网络之一,因其在计算机视觉任务中的高效性能而受到广泛关注。 这个模型是由AT&T贝尔实验室的研究员Yann LeCun在1989年提出的(并以其命名),目的是识别图像 (LeCun et al., 1998)中的手写数字。 当时,Yann LeCun发表了第一篇通过反向传播成功训练卷积神经网络的研究,这项工作代表了十多年来神经网络研究开发的成果。

当时,LeNet取得了与支持向量机(support vector machines)性能相媲美的成果,成为监督学习的主流方法。 LeNet被广泛用于自动取款机(ATM)机中,帮助识别处理支票的数字。 时至今日,一些自动取款机仍在运行Yann LeCun和他的同事Leon Bottou在上世纪90年代写的代码呢!

1. LeNet

总体来看,LeNet(LeNet-5)由两个部分组成:

  • 卷积编码器:由两个卷积层组成;

  • 全连接层密集块:由三个全连接层组成。

图1 LeNet中的数据流。输入是手写数字,输出为10种可能结果的概率

LeNet是早期成功的神经网络:2卷积层+2池化层+3全连接层

先使用卷积层来学习图片空间信息

然后使用全连接层来转换到类别空间

2. 实现代码

每个卷积块中的基本单元是一个卷积层、一个sigmoid激活函数和平均汇聚层。请注意,虽然ReLU和最大汇聚层更有效,但它们在20世纪90年代还没有出现。每个卷积层使用5×5卷积核和一个sigmoid激活函数。这些层将输入映射到多个二维特征输出,通常同时增加通道的数量。

第一卷积层有6个输出通道,而第二个卷积层有16个输出通道。

每个2×2池操作(步幅2)通过空间下采样将维数减少4倍。卷积的输出形状由批量大小、通道数、高度、宽度决定。

为了将卷积块的输出传递给稠密块,我们必须在小批量中展平每个样本。换言之,我们将这个四维输入转换成全连接层所期望的二维输入。这里的二维表示的第一个维度索引小批量中的样本,第二个维度给出每个样本的平面向量表示。LeNet的稠密块有三个全连接层,分别有120、84和10个输出。因为我们在执行分类任务,所以输出层的10维对应于最后输出结果的数量。

通过下面的LeNet代码,可以看出用深度学习框架实现此类模型非常简单。我们只需要实例化一个Sequential块并将需要的层连接在一起。

卷积-激活-池化   卷积-激活-池化  全连接x3

import torch
from torch import nn
from d2l import torch as d2lclass Reshape(torch.nn.Module):def forward(self, x):return x.view(-1, 1, 28, 28)net = nn.Sequential(Reshape( ),nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Flatten(),nn.Linear(16 * 5 * 5, 120), nn.Sigmoid(),nn.Linear(120, 84), nn.Sigmoid(),nn.Linear(84, 10))

我们对原始模型做了一点小改动,去掉了最后一层的高斯激活。除此之外,这个网络与最初的LeNet-5一致。

下面,我们将一个大小为28×28的单通道(黑白)图像通过LeNet。通过在每一层打印输出的形状,我们可以检查模型,以确保其操作与我们期望的 图6.6.2一致。

图2 LeNet 的简化版

下面我们查看验证一下模型的每一层,观察输入在网络里面的变化

X = torch.rand(size=(1, 1, 28, 28), dtype=torch.float32)
for layer in net:X = layer(X)print(layer.__class__.__name__,'output shape: \t',X.shape)
Conv2d output shape:         torch.Size([1, 6, 28, 28])
Sigmoid output shape:        torch.Size([1, 6, 28, 28])
AvgPool2d output shape:      torch.Size([1, 6, 14, 14])
Conv2d output shape:         torch.Size([1, 16, 10, 10])
Sigmoid output shape:        torch.Size([1, 16, 10, 10])
AvgPool2d output shape:      torch.Size([1, 16, 5, 5])
Flatten output shape:        torch.Size([1, 400])
Linear output shape:         torch.Size([1, 120])
Sigmoid output shape:        torch.Size([1, 120])
Linear output shape:         torch.Size([1, 84])
Sigmoid output shape:        torch.Size([1, 84])
Linear output shape:         torch.Size([1, 10])

请注意,在整个卷积块中,与上一层相比,每一层特征的高度和宽度都减小了。 第一个卷积层使用2个像素的填充,来补偿5×5卷积核导致的特征减少。 相反,第二个卷积层没有填充,因此高度和宽度都减少了4个像素。

随着层叠的上升,通道的数量从输入时的1个,增加到第一个卷积层之后的6个,再到第二个卷积层之后的16个。

同时,每个汇聚层的高度和宽度都减半。

最后,每个全连接层减少维数,最终输出一个维数与结果分类数相匹配的输出。

3. 模型训练

现在我们已经实现了LeNet,让我们看看LeNet在Fashion-MNIST数据集上的表现。

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)

虽然卷积神经网络的参数较少,但与深度的多层感知机相比,它们的计算成本仍然很高,因为每个参数都参与更多的乘法。 通过使用GPU,可以用它加快训练。

为了进行评估,我们需要对 3.6节中描述的evaluate_accuracy函数进行轻微的修改。 由于完整的数据集位于内存中,因此在模型使用GPU计算数据集之前,我们需要将其复制到显存中。

def evaluate_accuracy_gpu(net, data_iter, device=None): #@save"""使用GPU计算模型在数据集上的精度"""if isinstance(net, nn.Module):net.eval()  # 设置为评估模式if not device:device = next(iter(net.parameters())).device# 正确预测的数量,总预测的数量metric = d2l.Accumulator(2)with torch.no_grad():for X, y in data_iter:if isinstance(X, list):# BERT微调所需的(之后将介绍)X = [x.to(device) for x in X]else:X = X.to(device)y = y.to(device)metric.add(d2l.accuracy(net(X), y), y.numel())return metric[0] / metric[1]

为了使用GPU,我们还需要一点小改动。 与 3.6节中定义的train_epoch_ch3不同,在进行正向和反向传播之前,我们需要将每一小批量数据移动到我们指定的设备(例如GPU)上。

如下所示,训练函数train_ch6也类似于 3.6节中定义的train_ch3。 由于我们将实现多层神经网络,因此我们将主要使用高级API。 以下训练函数假定从高级API创建的模型作为输入,并进行相应的优化。 我们使用在 4.8.2.2节中介绍的Xavier随机初始化模型参数。 与全连接层一样,我们使用交叉熵损失函数和小批量随机梯度下降。

#@save
def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):"""用GPU训练模型(在第六章定义)"""def init_weights(m):if type(m) == nn.Linear or type(m) == nn.Conv2d:nn.init.xavier_uniform_(m.weight)net.apply(init_weights)print('training on', device)net.to(device)optimizer = torch.optim.SGD(net.parameters(), lr=lr)loss = nn.CrossEntropyLoss()animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],legend=['train loss', 'train acc', 'test acc'])timer, num_batches = d2l.Timer(), len(train_iter)for epoch in range(num_epochs):# 训练损失之和,训练准确率之和,样本数metric = d2l.Accumulator(3)net.train()for i, (X, y) in enumerate(train_iter):timer.start()optimizer.zero_grad()X, y = X.to(device), y.to(device)y_hat = net(X)l = loss(y_hat, y)l.backward()optimizer.step()with torch.no_grad():metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])timer.stop()train_l = metric[0] / metric[2]train_acc = metric[1] / metric[2]if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:animator.add(epoch + (i + 1) / num_batches,(train_l, train_acc, None))test_acc = evaluate_accuracy_gpu(net, test_iter)animator.add(epoch + 1, (None, None, test_acc))print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, 'f'test acc {test_acc:.3f}')print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec 'f'on {str(device)}')

现在,我们训练和评估LeNet-5模型。

lr, num_epochs = 0.9, 10
train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
loss 0.469, train acc 0.823, test acc 0.779
55296.6 examples/sec on cuda:0

4. 小结

  • 卷积神经网络(CNN)是一类使用卷积层的网络。

  • 在卷积神经网络中,我们组合使用卷积层、非线性激活函数和汇聚层。

  • 为了构造高性能的卷积神经网络,我们通常对卷积层进行排列,逐渐降低其表示的空间分辨率,同时增加通道数。

  • 在传统的卷积神经网络中,卷积块编码得到的表征在输出之前需由一个或多个全连接层进行处理。

  • LeNet是最早发布的卷积神经网络之一。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/735041.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【AIGC调研系列】GitHub Copilot提高接口自动化效率的使用技巧

GitHub Copilot 提高接口自动化效率的使用技巧主要包括以下几点: 让Copilot学习你的代码:通过清晰的层次结构、细致的拆分、规范的目录和文件命名以及合理的代码抽取,可以帮助Copilot更好地理解你的编码风格和需求[2]。设置Prompt和使用Prompt Engineering:通过设置合适的P…

【2024年5月备考新增】《软考真题分章练习(答案解析) - 1 项目管理一般知识(高项)》

点击跳转本章无答案版 1、关于项目管理的描述,不正确的是:()。 A.项目管理的主要目的是实现企业管理目标 B.在项目管理中,时间是一种特殊的资源 C.项目管理的职能是对资源进行计划、组织、指挥、协调、控制 D.项目管理把各种知识、技能、手段和技术应用于项目活动项目管理…

HTML三识

表单标签 表单&#xff1a;在网页中主要负责数据采集功能&#xff0c;使用<form>标签定义表单表单项(元素)&#xff1a;不同类型的input元素&#xff0c;下拉列表&#xff0c;文本域等 标签描述<form>定义表单<input>定义表单项&#xff0c;通过type属性控…

VMware下载与安装

准备一个Linux的系统&#xff0c;成本最低的方式就是在本地安装一台虚拟机&#xff0c;VMware是业界最好用的虚拟机软件之一 官网&#xff1a;https://www.vmware.com/ 下载页面&#xff1a;https://www.vmware.com/products/workstation-pro/workstation-pro-evaluation.html …

C++之线程的简单操作

在我们启动程序的时候&#xff0c;如果我们没有开启多线程的情况下&#xff0c;默认我们的程序就是单线程&#xff0c;是串行执行的。 如果开启了多线程那么我们有的程序语句就可以并行执行了。单核情况下还是并发&#xff0c;多核情况下可以并行。 就记住一点&#xff0c;单线…

代码随想录算法训练营Day55 | 583.两个字符串的删除操作、72.编辑距离

583.两个字符串的删除操作 最开始想到的是基于最长公共子序列的写法&#xff1a;删除公共子序列以外的字符&#xff0c;两个字符串就相同了 int minDistance0(string word1, string word2) {int n word1.size();int m word2.size();vector<vector<int>> dp(n …

小文件问题及GlusterFS的瓶颈

01海量小文件存储的挑战 为了解决海量小文件的存储问题&#xff0c;必须采用分布式存储&#xff0c;目前分布式存储主要采用两种架构&#xff1a;集中式元数据管理架构和去中心化架构。 (1)集中式元数据架构&#xff1a; 典型的集中式元数据架构的分布式存储有GFS&#xff0…

【深度学习笔记】7_1 优化与深度学习

注&#xff1a;本文为《动手学深度学习》开源内容&#xff0c;部分标注了个人理解&#xff0c;仅为个人学习记录&#xff0c;无抄袭搬运意图 7.1 优化与深度学习 本节将讨论优化与深度学习的关系&#xff0c;以及优化在深度学习中的挑战。在一个深度学习问题中&#xff0c;我们…

2023年第三届中国高校大数据挑战赛(第一场)A题思路

竞赛时间 &#xff08;1&#xff09;报名时间&#xff1a;即日起至2024年3月8日 &#xff08;2&#xff09;比赛时间&#xff1a;2024年3月9日8:00至2024年3月12日20:00 &#xff08;3&#xff09;成绩公布&#xff1a;2024年4月30日前 赛题方向&#xff1a;大数据统计分析 …

202109CSPT4收集卡牌

题意&#xff1a;有 n n n种不同的卡牌,每一次抽卡获得第 i i i种卡牌的概率为 p i pi pi。如果这张卡牌已经获得&#xff0c;就会转化为一枚硬币。可以用 k k k枚硬币交换一张没有获得过的卡。求获得所有种类的牌的概率。 #include<bits/stdc.h> using namespace std; …

SQL 注入攻击 - delete注入

环境准备:构建完善的安全渗透测试环境:推荐工具、资源和下载链接_渗透测试靶机下载-CSDN博客 一、注入原理: 对于后台来说,delete操作通常是将对应的id传递到后台,然后后台会删除该id对应的数据。 如果后台没有对接收到的 id 参数进行充分的验证和过滤,恶意用户可能会…

数据治理实践——YY 直播业务指标治理实践

目录 一、问题背景 1.1 问题场景 1.2 问题小结 二、治理方案 2.1 治理目标 2.2 团队协同&#xff0c;共建规范 2.3 指标管理的定位 2.4 指标管理的目标及思路 2.5 指标管理&#xff0c;规范内容落地 2.6 数仓设计-关联指标维度 2.7 数据报表开发-配置口径说明 2.8 …

StableDiffusion3 官方blog论文研究

博客源地址&#xff1a;Stable Diffusion 3: Research Paper — Stability AI 论文源地址&#xff1a;https://arxiv.org/pdf/2403.03206.pdf Stability.AI 官方发布了Stable diffusion 3.0的论文研究&#xff0c;不过目前大家都沉浸在SORA带来的震撼中&#xff0c;所以这个水…

将props展平传到DOM上

当我们将展平&#xff08;传播&#xff09;的属性设置子组件时&#xff0c;我们会引入风险&#xff0c;因为我们可能会往 HTML 标签上添加它不支持的属性。 坏实践 下面这个例子会在 DOM 元素上增加一个该元素本身不支持的属性flag。 const Sample () > (<Spread flag…

chrome插件chrome.storage数据写入失败QUOTA_BYTES_PER_ITEM quota exceeded

Unchecked runtime.lastError while running storage.set: QUOTA_BYTES_PER_ITEM quota exceeded at Object.callback 在开发浏览器插件的时候&#xff0c;报错提示&#xff1a;超出存储限制&#xff0c;浏览器插件存储官方文档&#xff1a;https://developer.chrome.com/docs…

Linux篇面试题 2024

目录 Java全技术栈面试题合集地址Linux篇1.绝对路径用什么符号表示&#xff1f;当前目录、上层目录用什么表示&#xff1f;主目录用什么表示? 切换目录用什么命令&#xff1f;2.怎么查看当前进程&#xff1f;怎么执行退出&#xff1f;怎么查看当前路径&#xff1f;3.怎么清屏&…

Keras用tf的Strategy()分布式训练时候报XLA错误

We failed to lift variable creations out of this tf.function, so this tf.function cannot be run on XLA. A possible workaround is to move variable creation outside of the XLA compiled function. 最早用的pip -U 安装的keras没注意版本&#xff0c;直接可用。 之…

国外网站,similarweb统计网站流量的原理是什么? 这个统计数据是准确的吗? 有没有 没统计不到的?

标签&#xff1a; SimilarWeb&#xff1b; 国外网站 SimilarWeb 提供网站流量和其他互联网指标的估计&#xff0c;主要用于网站分析、竞争对手分析和市场研究。它如何工作&#xff0c;以及其数据的准确性&#xff0c;是很多人关心的问题。下面是对SimilarWeb统计网站流量原理的…

Golang-channel合集——源码阅读、工作流程、实现原理、已关闭channel收发操作、优雅的关闭等面试常见问题。

前言 面试被问到好几次“channel是如何实现的”&#xff0c;我只会说“啊&#xff0c;就一块内存空间传递数据呗”…所以这篇文章来深入学习一下Channel相关。从源码开始学习其组成、工作流程及一些常见考点。 NO&#xff01;共享内存 Golang的并发哲学是“要通过共享内存的…

【AI辅助研发】-开端:未来的编程范式

编程的四种范式 面向机器编程范式 面向机器编程范式是最原始的编程方式&#xff0c;它直接针对计算机硬件进行操作。程序员需要了解计算机的内部结构、指令集和内存管理等细节。在这种范式下&#xff0c;编程的主要目标是编写能够直接控制计算机硬件运行的机器代码。面向机器…