23 经典卷积神经网络 LeNet【李沐动手学深度学习v2课程笔记】 (备注:提到如何把代码从CPU改到在GPU上使用)

目录

1. LeNet

2. 实现代码

3. 模型训练

4. 小结


本节将介绍LeNet,它是最早发布的卷积神经网络之一,因其在计算机视觉任务中的高效性能而受到广泛关注。 这个模型是由AT&T贝尔实验室的研究员Yann LeCun在1989年提出的(并以其命名),目的是识别图像 (LeCun et al., 1998)中的手写数字。 当时,Yann LeCun发表了第一篇通过反向传播成功训练卷积神经网络的研究,这项工作代表了十多年来神经网络研究开发的成果。

当时,LeNet取得了与支持向量机(support vector machines)性能相媲美的成果,成为监督学习的主流方法。 LeNet被广泛用于自动取款机(ATM)机中,帮助识别处理支票的数字。 时至今日,一些自动取款机仍在运行Yann LeCun和他的同事Leon Bottou在上世纪90年代写的代码呢!

1. LeNet

总体来看,LeNet(LeNet-5)由两个部分组成:

  • 卷积编码器:由两个卷积层组成;

  • 全连接层密集块:由三个全连接层组成。

图1 LeNet中的数据流。输入是手写数字,输出为10种可能结果的概率

LeNet是早期成功的神经网络:2卷积层+2池化层+3全连接层

先使用卷积层来学习图片空间信息

然后使用全连接层来转换到类别空间

2. 实现代码

每个卷积块中的基本单元是一个卷积层、一个sigmoid激活函数和平均汇聚层。请注意,虽然ReLU和最大汇聚层更有效,但它们在20世纪90年代还没有出现。每个卷积层使用5×5卷积核和一个sigmoid激活函数。这些层将输入映射到多个二维特征输出,通常同时增加通道的数量。

第一卷积层有6个输出通道,而第二个卷积层有16个输出通道。

每个2×2池操作(步幅2)通过空间下采样将维数减少4倍。卷积的输出形状由批量大小、通道数、高度、宽度决定。

为了将卷积块的输出传递给稠密块,我们必须在小批量中展平每个样本。换言之,我们将这个四维输入转换成全连接层所期望的二维输入。这里的二维表示的第一个维度索引小批量中的样本,第二个维度给出每个样本的平面向量表示。LeNet的稠密块有三个全连接层,分别有120、84和10个输出。因为我们在执行分类任务,所以输出层的10维对应于最后输出结果的数量。

通过下面的LeNet代码,可以看出用深度学习框架实现此类模型非常简单。我们只需要实例化一个Sequential块并将需要的层连接在一起。

卷积-激活-池化   卷积-激活-池化  全连接x3

import torch
from torch import nn
from d2l import torch as d2lclass Reshape(torch.nn.Module):def forward(self, x):return x.view(-1, 1, 28, 28)net = nn.Sequential(Reshape( ),nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Flatten(),nn.Linear(16 * 5 * 5, 120), nn.Sigmoid(),nn.Linear(120, 84), nn.Sigmoid(),nn.Linear(84, 10))

我们对原始模型做了一点小改动,去掉了最后一层的高斯激活。除此之外,这个网络与最初的LeNet-5一致。

下面,我们将一个大小为28×28的单通道(黑白)图像通过LeNet。通过在每一层打印输出的形状,我们可以检查模型,以确保其操作与我们期望的 图6.6.2一致。

图2 LeNet 的简化版

下面我们查看验证一下模型的每一层,观察输入在网络里面的变化

X = torch.rand(size=(1, 1, 28, 28), dtype=torch.float32)
for layer in net:X = layer(X)print(layer.__class__.__name__,'output shape: \t',X.shape)
Conv2d output shape:         torch.Size([1, 6, 28, 28])
Sigmoid output shape:        torch.Size([1, 6, 28, 28])
AvgPool2d output shape:      torch.Size([1, 6, 14, 14])
Conv2d output shape:         torch.Size([1, 16, 10, 10])
Sigmoid output shape:        torch.Size([1, 16, 10, 10])
AvgPool2d output shape:      torch.Size([1, 16, 5, 5])
Flatten output shape:        torch.Size([1, 400])
Linear output shape:         torch.Size([1, 120])
Sigmoid output shape:        torch.Size([1, 120])
Linear output shape:         torch.Size([1, 84])
Sigmoid output shape:        torch.Size([1, 84])
Linear output shape:         torch.Size([1, 10])

请注意,在整个卷积块中,与上一层相比,每一层特征的高度和宽度都减小了。 第一个卷积层使用2个像素的填充,来补偿5×5卷积核导致的特征减少。 相反,第二个卷积层没有填充,因此高度和宽度都减少了4个像素。

随着层叠的上升,通道的数量从输入时的1个,增加到第一个卷积层之后的6个,再到第二个卷积层之后的16个。

同时,每个汇聚层的高度和宽度都减半。

最后,每个全连接层减少维数,最终输出一个维数与结果分类数相匹配的输出。

3. 模型训练

现在我们已经实现了LeNet,让我们看看LeNet在Fashion-MNIST数据集上的表现。

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)

虽然卷积神经网络的参数较少,但与深度的多层感知机相比,它们的计算成本仍然很高,因为每个参数都参与更多的乘法。 通过使用GPU,可以用它加快训练。

为了进行评估,我们需要对 3.6节中描述的evaluate_accuracy函数进行轻微的修改。 由于完整的数据集位于内存中,因此在模型使用GPU计算数据集之前,我们需要将其复制到显存中。

def evaluate_accuracy_gpu(net, data_iter, device=None): #@save"""使用GPU计算模型在数据集上的精度"""if isinstance(net, nn.Module):net.eval()  # 设置为评估模式if not device:device = next(iter(net.parameters())).device# 正确预测的数量,总预测的数量metric = d2l.Accumulator(2)with torch.no_grad():for X, y in data_iter:if isinstance(X, list):# BERT微调所需的(之后将介绍)X = [x.to(device) for x in X]else:X = X.to(device)y = y.to(device)metric.add(d2l.accuracy(net(X), y), y.numel())return metric[0] / metric[1]

为了使用GPU,我们还需要一点小改动。 与 3.6节中定义的train_epoch_ch3不同,在进行正向和反向传播之前,我们需要将每一小批量数据移动到我们指定的设备(例如GPU)上。

如下所示,训练函数train_ch6也类似于 3.6节中定义的train_ch3。 由于我们将实现多层神经网络,因此我们将主要使用高级API。 以下训练函数假定从高级API创建的模型作为输入,并进行相应的优化。 我们使用在 4.8.2.2节中介绍的Xavier随机初始化模型参数。 与全连接层一样,我们使用交叉熵损失函数和小批量随机梯度下降。

#@save
def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):"""用GPU训练模型(在第六章定义)"""def init_weights(m):if type(m) == nn.Linear or type(m) == nn.Conv2d:nn.init.xavier_uniform_(m.weight)net.apply(init_weights)print('training on', device)net.to(device)optimizer = torch.optim.SGD(net.parameters(), lr=lr)loss = nn.CrossEntropyLoss()animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],legend=['train loss', 'train acc', 'test acc'])timer, num_batches = d2l.Timer(), len(train_iter)for epoch in range(num_epochs):# 训练损失之和,训练准确率之和,样本数metric = d2l.Accumulator(3)net.train()for i, (X, y) in enumerate(train_iter):timer.start()optimizer.zero_grad()X, y = X.to(device), y.to(device)y_hat = net(X)l = loss(y_hat, y)l.backward()optimizer.step()with torch.no_grad():metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])timer.stop()train_l = metric[0] / metric[2]train_acc = metric[1] / metric[2]if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:animator.add(epoch + (i + 1) / num_batches,(train_l, train_acc, None))test_acc = evaluate_accuracy_gpu(net, test_iter)animator.add(epoch + 1, (None, None, test_acc))print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, 'f'test acc {test_acc:.3f}')print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec 'f'on {str(device)}')

现在,我们训练和评估LeNet-5模型。

lr, num_epochs = 0.9, 10
train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
loss 0.469, train acc 0.823, test acc 0.779
55296.6 examples/sec on cuda:0

4. 小结

  • 卷积神经网络(CNN)是一类使用卷积层的网络。

  • 在卷积神经网络中,我们组合使用卷积层、非线性激活函数和汇聚层。

  • 为了构造高性能的卷积神经网络,我们通常对卷积层进行排列,逐渐降低其表示的空间分辨率,同时增加通道数。

  • 在传统的卷积神经网络中,卷积块编码得到的表征在输出之前需由一个或多个全连接层进行处理。

  • LeNet是最早发布的卷积神经网络之一。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/735041.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

VMware下载与安装

准备一个Linux的系统,成本最低的方式就是在本地安装一台虚拟机,VMware是业界最好用的虚拟机软件之一 官网:https://www.vmware.com/ 下载页面:https://www.vmware.com/products/workstation-pro/workstation-pro-evaluation.html …

小文件问题及GlusterFS的瓶颈

01海量小文件存储的挑战 为了解决海量小文件的存储问题,必须采用分布式存储,目前分布式存储主要采用两种架构:集中式元数据管理架构和去中心化架构。 (1)集中式元数据架构: 典型的集中式元数据架构的分布式存储有GFS&#xff0…

【深度学习笔记】7_1 优化与深度学习

注:本文为《动手学深度学习》开源内容,部分标注了个人理解,仅为个人学习记录,无抄袭搬运意图 7.1 优化与深度学习 本节将讨论优化与深度学习的关系,以及优化在深度学习中的挑战。在一个深度学习问题中,我们…

SQL 注入攻击 - delete注入

环境准备:构建完善的安全渗透测试环境:推荐工具、资源和下载链接_渗透测试靶机下载-CSDN博客 一、注入原理: 对于后台来说,delete操作通常是将对应的id传递到后台,然后后台会删除该id对应的数据。 如果后台没有对接收到的 id 参数进行充分的验证和过滤,恶意用户可能会…

数据治理实践——YY 直播业务指标治理实践

目录 一、问题背景 1.1 问题场景 1.2 问题小结 二、治理方案 2.1 治理目标 2.2 团队协同,共建规范 2.3 指标管理的定位 2.4 指标管理的目标及思路 2.5 指标管理,规范内容落地 2.6 数仓设计-关联指标维度 2.7 数据报表开发-配置口径说明 2.8 …

StableDiffusion3 官方blog论文研究

博客源地址:Stable Diffusion 3: Research Paper — Stability AI 论文源地址:https://arxiv.org/pdf/2403.03206.pdf Stability.AI 官方发布了Stable diffusion 3.0的论文研究,不过目前大家都沉浸在SORA带来的震撼中,所以这个水…

chrome插件chrome.storage数据写入失败QUOTA_BYTES_PER_ITEM quota exceeded

Unchecked runtime.lastError while running storage.set: QUOTA_BYTES_PER_ITEM quota exceeded at Object.callback 在开发浏览器插件的时候,报错提示:超出存储限制,浏览器插件存储官方文档:https://developer.chrome.com/docs…

Golang-channel合集——源码阅读、工作流程、实现原理、已关闭channel收发操作、优雅的关闭等面试常见问题。

前言 面试被问到好几次“channel是如何实现的”,我只会说“啊,就一块内存空间传递数据呗”…所以这篇文章来深入学习一下Channel相关。从源码开始学习其组成、工作流程及一些常见考点。 NO!共享内存 Golang的并发哲学是“要通过共享内存的…

【AI辅助研发】-开端:未来的编程范式

编程的四种范式 面向机器编程范式 面向机器编程范式是最原始的编程方式,它直接针对计算机硬件进行操作。程序员需要了解计算机的内部结构、指令集和内存管理等细节。在这种范式下,编程的主要目标是编写能够直接控制计算机硬件运行的机器代码。面向机器…

redis使用笔记

redis使用笔记 1、Redis简介1.1 含义1.2 功能1.3 特点 2. 常用的数据结构2.1 HASH 3 redis接口定义3.1 redisReply3.2 redisContext3.3 redisCommand 4 实践操作4.1 遇到问题4.1.1 Get哈希的时候返回error4.1.2 长度一直为0,str没法打印(未解决&#xff…

java正则表达式概述及案例

前言: 学习了正则表达式,记录下使用心得。打好基础,daydayup! 正则表达式 什么是正则表达式 正则表达式由一些特定的字符组成,代表一个规则。 正则表达式的功能 1:用来校验数据格式是否合规 2:在一段文本…

2024,互联网打工人最终没能逃得过 AI

时间很快就来到了三月份,回首看过去的一年,如果要选择最令人着迷的新技术,那非 ChatGPT 莫属。 从美国的硅谷、华尔街到中国的后厂村、中关村,几乎所有的科技大厂们都在讨论“AIGC”。 既 ChatGPT 之后,几乎每天都有…

【深度学习笔记】7_2 梯度下降和随机梯度下降

注:本文为《动手学深度学习》开源内容,部分标注了个人理解,仅为个人学习记录,无抄袭搬运意图 7.2 梯度下降和随机梯度下降 在本节中,我们将介绍梯度下降(gradient descent)的工作原理。虽然梯度…

️网络爬虫与IP代理:双剑合璧,数据采集无障碍️

博主猫头虎的技术世界 🌟 欢迎来到猫头虎的博客 — 探索技术的无限可能! 专栏链接: 🔗 精选专栏: 《面试题大全》 — 面试准备的宝典!《IDEA开发秘籍》 — 提升你的IDEA技能!《100天精通鸿蒙》 …

day16_购物车(添加购物车,购物车列表查询,删除购物车商品,更新选中商品状态,完成购物车商品的全选,清空购物车)

文章目录 购物车模块1 需求说明2 环境搭建3 添加购物车3.1 需求说明3.2 远程调用接口开发3.2.1 ProductController3.2.2 ProductService 3.3 openFeign接口定义3.3.1 环境搭建3.3.2 接口定义3.3.3 降级类定义 3.4 业务后端接口开发3.4.1 添加依赖3.4.2 修改启动类3.4.3 CartInf…

基于springboot实现摄影网站系统项目【项目源码】

基于springboot实现摄影网站系统演示 摘要 随着时代的进步,社会生产力高速发展,新技术层出不穷信息量急剧膨胀,整个社会已成为信息化的社会人们对信息和数据的利用和处理已经进入自动化、网络化和社会化的阶段。如在查找情报资料、处理银行账…

invoke()到底是个什么方法???

调用jquery的方法返回属性值 1、invoke(‘val’) 在form的select下: cy.get(.action-select-multiple).select([apples, oranges, bananas])// when getting multiple values, invoke "val" method first jquery中val方法是用于返…

花店小程序有哪些功能 怎么制作

​花店小程序可以为花店提供一个全新的线上销售平台,帮助花店扩大市场份额,提升用户体验,增加销售额。下面我们来看看花店小程序应该具备哪些功能,以满足用户的需求。 1. 商品展示:展示花店的各类花卉和花束&#xff…

Vue.js数据绑定解密:深入探究v-model和v-bind的原理与应用

hello宝子们...我们是艾斯视觉擅长ui设计和前端开发10年经验!希望我的分享能帮助到您!如需帮助可以评论关注私信我们一起探讨!致敬感谢感恩! Vue.js数据绑定解密:深入探究v-model和v-bind的原理与应用 一、引言 Vue.…

Linux多线程之线程互斥

(。・∀・)ノ゙嗨!你好这里是ky233的主页:这里是ky233的主页,欢迎光临~https://blog.csdn.net/ky233?typeblog 点个关注不迷路⌯▾⌯ 目录 一、互斥 1.线程间的互斥相关背景概念 2.互…