【动手学深度学习】LeNet:卷积神经网络的开山之作

【动手学深度学习】LeNet:卷积神经网络的开山之作

  • 1,LeNet卷积神经网络简介
  • 2,Fashion-MNIST图像分类数据集
  • 3,LeNet总体架构
  • 4,LeNet代码实现
    • 4.1,定义LeNet模型
    • 4.2,定义模型评估函数
    • 4.3,定义训练函数进行训练


1,LeNet卷积神经网络简介

LeNet 是一种经典的卷积神经网络,是现代卷积神经网络的起源之一。它是早期成功的神经网络;LeNet先使用卷积层来学习图片空间信息,使用池化层降低图片敏感度,然后使用全连接层来转换到类别空间。 其思想被广泛应用于图像分类、目标检测、图像分割等多个计算机视觉领域,为这些领域的研究和发展提供了新的思路和方法。例如,在安防领域用于面部识别和监控系统,在自动驾驶领域用于实时视频分析和对象跟踪等。

1989年,Yann LeCun等人在贝尔实验室工作期间提出了LeNet-1。这个网络主要用于手写数字识别,引入了卷积操作和权值共享的概念,简化了网络结构,减少了参数数量,提高了模型的泛化能力和训练速度。此后经过多年的迭代改进,1998年,LeCun等人正式发表了LeNet-5。LeNet-5在LeNet-1的基础上进一步优化了网络结构,增加了网络的深度和复杂度,使其在手写数字识别任务上取得了更好的性能。LeNet-5的成功应用证明了CNN在图像识别领域的巨大潜力,为后续CNN的发展奠定了坚实的基础。


2,Fashion-MNIST图像分类数据集

Fashion-MNIST数据集是一个广泛使用的图像分类数据集。Fashion-MNIST中包含的10个类别,分别为t-shirt(T恤)、trouser(裤子)、pullover(套衫)、dress(连衣裙)、coat(外套)、sandal(凉鞋)、shirt(衬衫)、sneaker(运动鞋)、bag(包)和ankle boot(短靴)。

之前,已经学习过Fashion-MNIST数据集。 【动手学深度学习】Fashion-MNIST图片分类数据集,其基本情况如下:

  • 训练集:包含60,000张图像,用于模型训练;
  • 测试集:包含10,000张图像,用于评估模型性能;
  • 数据集由灰度图像组成,其通道数为1;
  • 每个图像的高度和宽度均为28像素;
  • 调用load_data_fashion_mnist()函数加载数据集;

具体定义如下:

"""
下载Fashion-MNIST数据集,然后将其加载到内存中
参数resize表示调整图片大小
"""
def load_data_fashion_mnist(batch_size, resize=None): # trans是一个用于转换的 *列表*trans = [transforms.ToTensor()]if resize:    # resize不为空,表示需要调整图片大小trans.insert(0, transforms.Resize(resize))trans = transforms.Compose(trans)mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)return (data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers()),data.DataLoader(mnist_test, batch_size, shuffle=False,num_workers=get_dataloader_workers()))

3,LeNet总体架构

总体来看,LeNet(LeNet-5)由两个部分组成:

  • 卷积编码器:由两个卷积层组成;
  • 全连接层密集块:由三个全连接层组成;

在这里插入图片描述

每个卷积块中的基本单元是一个卷积层、一个sigmoid激活函数和平均汇聚层。(实际上使用ReLU激活函数和最大汇聚层更有效,但当时还没有发现):

  • Fashion-MNIST数据集的图像通道为1,大小为28×28,内部经过卷积层填充之后得到的实际输入数据是32×32的图像数据

  • 第一卷积层有6个输出通道,而第二个卷积层有16个输出通道;

  • 对应输出通道的数量,第一个卷积层有6个5×5的卷积核,第二个卷积层有16个5×5的卷积核;

  • 每个卷积核应用于输入数据时会产生一个特征图(feature map),也就是一个输出通道;

  • 每个卷积层都使用不同数量的5×5的卷积核和一个sigmoid激活函数。这些层将输入映射到多个二维特征输出,通常同时增加通道的数量;

  • 卷积操作后,通过2×2的池化操作默认步幅为2和池化窗口大小保持一致)将原特征图的各维度减半。比如原来是28×28,池化后变为14×14;


4,LeNet代码实现

接下来使用深度学习框架实现LeNet模型,并进行训练和测试。


4.1,定义LeNet模型

LeNet模型总共七层: 两层卷积层、两层池化层、三层全连接层; 其中每层都使用sigmod作为激活函数,它将卷积层的输出压缩到0和1之间,有助于非线性变换。

import torch
from torch import nn
from d2l import torch as d2l
""" 默认情况下,深度学习框架中的步幅与汇聚窗口的大小相同(窗口没有重叠)"""# nn.Sequential 是一个容器,可按顺序包装一系列子模块(如层、激活函数)。使得模型的构建变得更加简洁
net = nn.Sequential(# 第一个二维卷积层,输入通道是1(灰度图像),输出通道是6,卷积核大小5×5,图像周围加入两层0填充# 使用sigmod激活函数nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.Sigmoid(),# 第一个平均池化层:用2x2的池化窗口,步长为2。经此池化操作后得6个14×14的特征图nn.AvgPool2d(kernel_size=2, stride=2),# 这是第二个二维卷积层,输入通道数为6(与第一个卷积层的输出通道数相匹配),输出通道数为16。卷积核的大小为5x5,没有使用padding填充# 使用sigmod激活函数nn.Conv2d(6, 16, kernel_size=5), nn.Sigmoid(),# 第二个平均池化层:配置与第一层平均池化层相同。nn.AvgPool2d(kernel_size=2, stride=2),# 在将数据传递给全连接层之前,需要将多维的卷积和池化输出展平为一维向量。以便传给全连接层nn.Flatten(),# 经过前面的卷积和池化操作后,输出16个5×5的特征图# 全连接层,输入特征的数量是16 * 5 * 5。输出特征的数量是120。nn.Linear(16 * 5 * 5, 120), nn.Sigmoid(),# 全连接层,输入特征数量120,输出84nn.Linear(120, 84), nn.Sigmoid(),# 全连接层,输入特征数量84,输出10,对应Fashion-MNIST数据集的10个类别nn.Linear(84, 10))

下面,我们将一个大小为 28 × 28 28 \times 28 28×28的单通道(黑白)图像通过LeNet。通过在每一层打印输出的形状,我们可以检查模型,以确保其操作与我们期望的一致。

# 打印调试信息,检查模型
# size=(1, 1, 28, 28):批次大小1,通道数1,形状28*28 
X = torch.rand(size=(1, 1, 28, 28), dtype=torch.float32)# 遍历了神经网络 net 中的每一层
for layer in net:X = layer(X)# 打印该层的类型(Conv2d、AvgPool2d、Flatten、Linear)以及输出张量的形状print(layer.__class__.__name__,'output shape: \t',X.shape)# torch.Size([1, 6, 28, 28])中的1代表批次大小,6表示通道数

运行结果如下:

在这里插入图片描述


4.2,定义模型评估函数

我们已经实现了LeNet,接下来让我们看看LeNet在Fashion-MNIST数据集上的表现。

加载Fashion-MNIST图片分类数据集

batch_size = 256  # 批量大小
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)

定义评估函数计算预测准确率

def evaluate_accuracy_gpu(net, data_iter, device=None):"""使用GPU计算模型在数据集上的精度"""if isinstance(net, nn.Module):net.eval()  # 设置为评估模式if not device: # 若没有指定device,则通过获取模型参数的第一个元素的设备来确定应该使用的设备# net.parameters()返回模型的所有可学习参数(如权重和偏置)# next() 函数从迭代器中获取第一个元素。通常是第一个层的权重或偏置# .device 是 PyTorch 张量(torch.Tensor)的一个属性,表示该张量所在的备(如 GPU 或 CPU)# 例如,模型在 GPU 上运行,.device 的值可能是 device(type='cuda', index=0)device = next(iter(net.parameters())).device# 累加器记录正确预测的数量和总预测的数量metric = d2l.Accumulator(2)with torch.no_grad():  # 评估模型时,不需要计算梯度for X, y in data_iter: # 每次迭代获取一个数据批次X和对应的标签yif isinstance(X, list):  # x为list,每个元素都挪到对应的设备X = [x.to(device) for x in X]else:   # x是tensor,只需要挪一次X = X.to(device)y = y.to(device)# accuracy可以计算出预测正确的样本数量# y.numel()计算出样本总数metric.add(d2l.accuracy(net(X), y), y.numel())return metric[0] / metric[1]

4.3,定义训练函数进行训练

定义可以使用GPU训练的训练函数。

def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):"""用GPU训练模型"""def init_weights(m): # 初始化权重# 如果是全连接层或卷积层使用Xavier均匀初始化方法if type(m) == nn.Linear or type(m) == nn.Conv2d:nn.init.xavier_uniform_(m.weight)net.apply(init_weights)print('training on', device)# 模型移动到设备net.to(device)# 使用随机梯度下降(SGD)优化器optimizer = torch.optim.SGD(net.parameters(), lr=lr)# 使用交叉熵损失函数(nn.CrossEntropyLoss),适用于分类任务loss = nn.CrossEntropyLoss()# 实现动画效果打印输出animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],legend=['train loss', 'train acc', 'test acc'])timer, num_batches = d2l.Timer(), len(train_iter)for epoch in range(num_epochs):# 累加器记录训练损失之和,训练准确率之和,样本数metric = d2l.Accumulator(3)# 将模型设置为训练模式,这会启用Dropout等训练时特有的操作net.train()for i, (X, y) in enumerate(train_iter): # 遍历训练数据集timer.start()optimizer.zero_grad() # 梯度清零X, y = X.to(device), y.to(device) # 将输入数据X和标签y移动到指定的设备# 前向传播,得到预测结果 y_haty_hat = net(X) """在 PyTorch 中,nn.CrossEntropyLoss 默认会对每个样本的损失值进行平均,返回的是批次中所有样本损失的平均值。"""l = loss(y_hat, y) # 计算损失# 进行反向传播,计算梯度。l.backward()# 使用优化器更新模型参数。optimizer.step()with torch.no_grad():  # 禁用梯度计算# l * X.shape[0]是当前批次的总损失。样本平均损失乘当前批次样本数# d2l.accuracy(y_hat, y) 计算当前批次正确预测的样本数# X.shape[0]代表当前批次的样本数# 最终累加器累积了整个训练集的总损失,预测正确的样本总数和总样本数metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])timer.stop()# 计算整个训练集上每个训练样本的平均损失train_l = metric[0] / metric[2]# 计算训练准确率train_acc = metric[1] / metric[2]# 更新动画if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:animator.add(epoch + (i + 1) / num_batches,(train_l, train_acc, None))# 在每个epoch结束时,计算测试集上的准确率test_acc = evaluate_accuracy_gpu(net, test_iter)# 更新动画animator.add(epoch + 1, (None, None, test_acc))# 打印训练损失、训练准确率和测试准确率print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, 'f'test acc {test_acc:.3f}')# 打印训练过程中每秒处理的样本数(即训练效率),以及训练所使用的设备。print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec 'f'on {str(device)}')

调用函数进行训练

lr, num_epochs = 0.9, 10
train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())

运行结果如下:

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/76516.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

代码随想录第15天:(二叉树)

一、二叉搜索树的最小绝对差(Leetcode 530) 思路1 :中序遍历将二叉树转化为有序数组,然后暴力求解。 class Solution:def __init__(self):# 初始化一个空的列表,用于保存树的节点值self.vec []def traversal(self, r…

计算机操作系统-【死锁】

文章目录 一、什么是死锁?死锁产生的原因?死锁产生的必要条件?互斥条件请求并保持不可剥夺环路等待 二、处理死锁的基本方法死锁的预防摒弃请求和保持条件摒弃不可剥夺条件摒弃环路等待条件 死锁的避免银行家算法案例 提示:以下是…

vue拓扑图组件

vue拓扑图组件 介绍技术栈功能特性快速开始安装依赖开发调试构建部署 使用示例演示截图组件源码 介绍 一个基于 Vue3 的拓扑图组件,具有以下特点: 1.基于 vue-flow 实现,提供流畅的拓扑图展示体验 2.支持传入 JSON 对象自动生成拓扑结构 3.自…

go 通过汇编分析函数传参与返回值机制

文章目录 概要一、前置知识二、汇编分析2.1、示例2.2、汇编2.2.1、 寄存器传值的汇编2.2.2、 栈内存传值的汇编 三、拓展3.1 了解go中的Duff’s Device3.2 go tool compile3.2 call 0x46dc70 & call 0x46dfda 概要 在上一篇文章中,我们研究了go函数调用时的栈布…

python-1. 找单独的数

问题描述 在一个班级中,每位同学都拿到了一张卡片,上面有一个整数。有趣的是,除了一个数字之外,所有的数字都恰好出现了两次。现在需要你帮助班长小C快速找到那个拿了独特数字卡片的同学手上的数字是什么。 要求: 设…

算法学习C++需注意的基本知识

文章目录 01_算法中C需注意的基本知识cmath头文件一些计算符ASCII码表数据类型长度运算符cout固定输出格式浮点数的比较max排序自定义类型字符的大小写转换与判断判断字符是数字还是字母 02_数据结构需要注意的内容1.stringgetline函数的使用string::findsubstr截取字符串strin…

从零开始写android 的智能指针

Android中定义了两种智能指针类型,一种是强指针sp(strong pointer),源码中的位置在system/core/include/utils/StrongPointer.h。另外一种是弱指针(weak pointer)。其实称之为强引用和弱引用更合适一些。强…

【leetcode hot 100 152】乘积最大子数组

错误解法:db[i]表示以i结尾的最大的非空连续,动态规划:dp[i] Math.max(nums[i], nums[i] * dp[i - 1]); class Solution {public int maxProduct(int[] nums) {int n nums.length;int[] dp new int[n]; // db[i]表示以i结尾的最大的非空连…

图论整理复习

回溯: 模板: void backtracking(参数) {if (终止条件) {存放结果;return;}for (选择:本层集合中元素(树中节点孩子的数量就是集合的大小)) {处理节点;backtracking(路径,选择列表); // 递归回溯&#xff…

uniapp离线打包提示未添加videoplayer模块

uniapp中使用到video标签,但是离线打包放到安卓工程中,运行到真机中时提示如下: 解决方案: 1、把media-release.aar、weex_videoplayer-release.aar放到工程的libs目录下; 文档:https://nativesupport.dcloud.net.cn/…

打包构建替换App名称

方案适用背景 一套代码出多个安装包,且安装包的应用名称、图标都不一样考虑三语名称问题 通过 Gradle 脚本实现 gradle.properties 里面定义标识来区分应用,如下文里的 APP_TYPEAAA 、APP_TYPEBBB// 定义 groovy 替换方法 def replaceAppName(String …

DrissionPage移动端自动化:从H5到原生App的跨界测试

一、移动端自动化测试的挑战与机遇 移动端测试面临多维度挑战: 设备碎片化:Android/iOS版本、屏幕分辨率差异 混合应用架构:H5页面与原生组件的深度耦合 交互复杂性:多点触控、手势操作、传感器模拟 性能监控:内存…

达梦数据库用函数实现身份证合法校验

达梦数据库用函数实现身份证合法校验 拿走不谢~ CREATE OR REPLACE FUNCTION CHECK_IDCARD(A_SFZ IN VARCHAR2) RETURN VARCHAR2 IS TYPE WEIGHT_TAB IS VARRAY(17) OF NUMBER; TYPE CHECK_TAB IS VARRAY(11) OF CHAR; WEIGHT_FACTOR WEIGHT_TAB : WEIGHT_TAB(7,9,10,5,8,4,…

3dmax的python通过普通的摄像头动捕表情

1、安装python 进入cdm,打python要能显示版本号 >>>(进入python提示符模式) import sys sys.path显示python的安装路径, 进入到python.exe的路径 在python目录中安装(ctrlz退出python交互模式) 2、pip install mediapipe…

国产Linux统信安装mysql8教程步骤

系统环境 uname -a Linux FlencherHU-PC 6.12.9-amd64-desktop-rolling #23.01.01.18 SMP PREEMPT_DYNAMIC Fri Jan 10 18:29:31 CST 2025 x86_64 GNU/Linux下载离线安装包 浏览器下载https://downloads.mysql.com/archives/get/p/23/file/mysql-test-8.0.33-linux-glibc2.28…

Vite 权限绕过导致任意文件读取(CVE-2025-32395)(附脚本)

免责申明: 本文所描述的漏洞及其复现步骤仅供网络安全研究与教育目的使用。任何人不得将本文提供的信息用于非法目的或未经授权的系统测试。作者不对任何由于使用本文信息而导致的直接或间接损害承担责任。如涉及侵权,请及时与我们联系,我们将尽快处理并删除相关内容。 前言…

poi-tl

官网地址 Poi-tl Documentationword模板引擎https://deepoove.com/poi-tl github 地址 https://github.com/Sayi/poi-tl/tree/master gitcode 加速地址 GitCode - 全球开发者的开源社区,开源代码托管平台GitCode是面向全球开发者的开源社区,包括原创博客,开源代码托管,代码…

操作系统 4.1-I/O与显示器

外设工作起来 操作系统让外设工作的基本原理和过程,具体来说,它概括了以下几个关键步骤: 发出指令:操作系统通过向控制器中的寄存器发送指令来启动外设的工作。这些指令通常是通过I/O指令(如out指令)来实现…

琥珀扫描 2.0.5.0 | 文档处理全能助手,支持扫描、文字提取及表格识别

琥珀扫描是一款功能强大的文档处理应用程序。它不仅仅支持基本的文档扫描功能,还涵盖了文字提取、证件扫描、表格识别等多种实用功能。无论是学生、职员还是教师,都能从中找到适合自己的功能。该应用支持拍照生成电子件,并能自动矫正文档边缘…

jQuery UI 小部件方法调用详解

jQuery UI 小部件方法调用详解 引言 jQuery UI 是一个基于 jQuery 的用户界面和交互库,它提供了一系列小部件,如按钮、对话框、进度条等,这些小部件极大地丰富了网页的交互性和用户体验。本文将详细介绍 jQuery UI 中小部件的方法调用,帮助开发者更好地理解和应用这些小部…