动手学深度学习—卷积神经网络LeNet(代码详解)

1. LeNet

LeNet由两个部分组成:

  • 卷积编码器:由两个卷积层组成;
  • 全连接层密集块:由三个全连接层组成。

在这里插入图片描述

  1. 每个卷积块中的基本单元是一个卷积层、一个sigmoid激活函数和平均汇聚层;
  2. 每个卷积层使用5×5卷积核和一个sigmoid激活函数;
  3. 这些层将输入映射到多个二维特征输出,通常同时增加通道的数量;
  4. 每个4×4池操作(步幅2)通过空间下采样将维数减少4倍。
import torch
from torch import nn
from d2l import torch as d2l# 定义模型net
net = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Flatten(),nn.Linear(16 * 5 * 5, 120), nn.Sigmoid(),nn.Linear(120, 84), nn.Sigmoid(),nn.Linear(84, 10))

该模型去掉了最后一层的高斯激活,下面将一个大小为28×28的单通道(黑白)图像通过LeNet,打印每一层输出的形状。

# 观察各层的输入输出通道数,宽度和高度
X = torch.rand(size=(1, 1, 28, 28), dtype=torch.float32)
for layer in net:X = layer(X)print(layer.__class__.__name__,'output shape:\t', X.shape)

在这里插入图片描述

  1. 第一个卷积层使用2个像素的填充,来补偿5×5卷积核导致的特征减少;
  2. 第二个卷积层没有填充,因此高度和宽度都减少了4个像素;
  3. 随着层叠的上升,通道的数量从输入时的1个,增加到第一个卷积层之后的6个,再到第二个卷积层之后的16个;
  4. 每个汇聚层的高度和宽度都减半;
  5. 每个全连接层减少维数,最终输出一个维数与结果分类数相匹配的输出。

2. 模型训练

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)
"""定义精度评估函数:1、将数据集复制到显存中2、通过调用accuracy计算数据集的精度
"""
def evaluate_accuracy_gpu(net, data_iter, device=None): #@save# 判断net是否属于torch.nn.Module类if isinstance(net, nn.Module):net.eval()# 如果不在参数选定的设备,将其传输到设备中if not device:device = next(iter(net.parameters())).device# Accumulator是累加器,定义两个变量:正确预测的数量,总预测的数量。metric = d2l.Accumulator(2)with torch.no_grad():for X, y in data_iter:# 将X, y复制到设备中if isinstance(X, list):# BERT微调所需的(之后将介绍)X = [x.to(device) for x in X]else:X = X.to(device)y = y.to(device)# 计算正确预测的数量,总预测的数量,并存储到metric中metric.add(d2l.accuracy(net(X), y), y.numel())return metric[0] / metric[1]
"""定义GPU训练函数:1、为了使用gpu,首先需要将每一小批量数据移动到指定的设备(例如GPU)上;2、使用Xavier随机初始化模型参数;3、使用交叉熵损失函数和小批量随机梯度下降。
"""
#@save
def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):"""用GPU训练模型(在第六章定义)"""# 定义初始化参数,对线性层和卷积层生效def init_weights(m):if type(m) == nn.Linear or type(m) == nn.Conv2d:nn.init.xavier_uniform_(m.weight)net.apply(init_weights)# 在设备device上进行训练print('training on', device)net.to(device)# 优化器:随机梯度下降optimizer = torch.optim.SGD(net.parameters(), lr=lr)# 损失函数:交叉熵损失函数loss = nn.CrossEntropyLoss()# Animator为绘图函数animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],legend=['train loss', 'train acc', 'test acc'])# 调用Timer函数统计时间timer, num_batches = d2l.Timer(), len(train_iter)for epoch in range(num_epochs):# Accumulator(3)定义3个变量:损失值,正确预测的数量,总预测的数量metric = d2l.Accumulator(3)net.train()# enumerate() 函数用于将一个可遍历的数据对象for i, (X, y) in enumerate(train_iter):timer.start() # 进行计时optimizer.zero_grad() # 梯度清零X, y = X.to(device), y.to(device) # 将特征和标签转移到devicey_hat = net(X)l = loss(y_hat, y) # 交叉熵损失l.backward() # 进行梯度传递返回optimizer.step()with torch.no_grad():# 统计损失、预测正确数和样本数metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])timer.stop() # 计时结束train_l = metric[0] / metric[2] # 计算损失train_acc = metric[1] / metric[2] # 计算精度# 进行绘图if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:animator.add(epoch + (i + 1) / num_batches,(train_l, train_acc, None))# 测试精度test_acc = evaluate_accuracy_gpu(net, test_iter) animator.add(epoch + 1, (None, None, test_acc))# 输出损失值、训练精度、测试精度print(f'loss {train_l:.3f}, train acc {train_acc:.3f},'f'test acc {test_acc:.3f}')# 设备的计算能力print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec'f'on {str(device)}')

在这里插入图片描述

lr, num_epochs = 0.9, 10
train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())

在这里插入图片描述

3. 小结

  1. 卷积神经网络(CNN)是一类使用卷积层的网络;
  2. 卷积神经网络中,可以组合使用卷积层、非线性激活函数和汇聚层;
  3. 为了构造高性能的卷积神经网络,通常对卷积层进行排列,逐渐降低其表示的空间分辨率,同时增加通道数;
  4. 在传统的卷积神经网络中,卷积块编码得到的表征在输出之前需由一个或多个全连接层进行处理。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/43110.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

LeetCode--HOT100题(35)

目录 题目描述:23. 合并 K 个升序链表(困难)题目接口解题思路1代码解题思路2代码 PS: 题目描述:23. 合并 K 个升序链表(困难) 给你一个链表数组,每个链表都已经按升序排列。 请你将所有链表合…

UDP 的报文结构以及注意事项

UDP协议 1.UDP协议端格式 1.图中的16位UDP长度,表示整个数据报(UDP首部UDP数据)的最大长度 2.若校验和出错,会直接丢弃 2.UDP的报文结构 UDP报文主体分为两个部分:UDP报头(占8个字节)UDP载荷/UDP数据 1.源端口号 16位,2个字节 2.目的端口号 16位,2个字节 3.包长度 指示了…

sd-webui安装comfyui扩展

文章目录 导读ComfyUI 环境安装1. 安装相关组件2. 启动sd-webui3. 访问sd-webui 错误信息以及解决办法 导读 这篇文章主要给大家介绍如何在sd-webui中来安装ComfyUI插件 ComfyUI ComfyUI是一个基于节点流程式的stable diffusion的绘图工具,它集成了stable diffus…

什么是闭包(closure)?为什么它在JavaScript中很有用?

聚沙成塔每天进步一点点 ⭐ 专栏简介⭐ 闭包(Closure)是什么?⭐ 闭包的用处⭐ 写在最后 ⭐ 专栏简介 前端入门之旅:探索Web开发的奇妙世界 记得点击上方或者右侧链接订阅本专栏哦 几何带你启航前端之旅 欢迎来到前端入门之旅&…

windows10 安装WSL2, Ubuntu,docker

AI- 通过docker开发调试部署ChatLLM 阅读时长:10分钟 本文内容: window上安装ubuntu虚拟机,并在虚拟机中安装docker,通过docker部署数字人模型,通过vscode链接到虚拟机进行开发调试.调试完成后,直接部署在云…

变更通知在开源SpringBoot/SpringCloud微服务中的最佳实践

目录导读 变更通知在开源SpringBoot/SpringCloud微服务中的最佳实践1. 什么是变更通知2. 变更通知的场景分析3. 变更通知的技术方案3.1 变更通知的技术实现方案 4. 变更通知的最佳实践总结5. 参考资料 变更通知在开源SpringBoot/SpringCloud微服务中的最佳实践 1. 什么是变更通…

Ubuntu在自己的项目中使用pcl

1、建立一个文件夹,如pcl_demos,里面建立一个.cpp文件和一个cmake文件 2、打开终端并进入该文件夹下,建立一个build文件夹存放编译的结果并进入该文件夹 3、对上一级进行编译 cmake .. 4、生成可执行文件 make 5、运行该可执行文件 6、可视…

微服务中间件-分布式缓存Redis

分布式缓存 a.Redis持久化1) RDB持久化1.a) RDB持久化-原理 2) AOF持久化3) 两者对比 b.Redis主从1) 搭建主从架构2) 数据同步原理(全量同步)3) 数据同步原理(增量同步) c.Redis哨兵1) 哨兵的作用2) 搭建Redis哨兵集群3) RedisTem…

金融语言模型:FinGPT

项目简介 FinGPT是一个开源的金融语言模型(LLMs),由FinNLP项目提供。这个项目让对金融领域的自然语言处理(NLP)感兴趣的人们有了一个可以自由尝试的平台,并提供了一个与专有模型相比更容易获取的金融数据。…

VS2015项目中,MFC内存中调用DLL函数(VC6生成的示例DLL)

本例主要讲一下,用VC6如何生成DLL,用工具WinHex取得DLL全部内容,VC2015项目加载内存中的DLL函数,并调用函数的示例。 本例中的示例代码下载,点击可以下载 一、VC6.0生成示例DLL项目 1.新建项目,…

SQL Server Express 自动备份方案

文章目录 SQL Server Express 自动备份方案前言方案原理SQL Server Express 自动备份1.创建存储过程2.设定计划任务3.结果检查sqlcmd 参数说明SQL Server Express 自动备份方案 前言 对于许多小型企业和个人开发者来说,SQL Server Express是一个经济实惠且强大的数据库解决方…

Spring Framework中的Bean生命周期

目录 一.Bean生命周期的简介 1.基本概念 2.Spring生命周期的几大阶段 3.注意点及小结 4.生活案例 5.Spring容器管理JavaBean的初始化过程 二. Bean的单例选择与多例选择 1.单例选择与多例选择的优缺点 1.1单例模式的优点: 1.2单例模式的缺点: 1…

JDK 8 升级 JDK 17 全流程教学指南

JDK 8 升级 JDK 17 首先已有项目升级是会经历一个较长的调试和自测过程来保证允许和兼容没有问题。先说几个重要的点 遇到问题别放弃仔细阅读报错,精确到每个单词每一行,不是自己项目的代码也要点进去看看源码到底是为啥报错明确你项目引入的包&#x…

第三届“赣政杯”网络安全大赛 | 赛宁筑牢安全应急防线

​​为持续强化江西省党政机关网络安全风险防范意识,提高信息化岗位从业人员基础技能,提升应对网络安全风险处置能力。由江西省委网信办、江西省发展改革委主办,江西省大数据中心、国家计算机网络与信息安全管理中心江西分中心承办&#xff0…

Qt扫盲-QTableView理论总结

QTableView理论总结 一、概述二、导航三、视觉外观四、坐标系统五、示例代码1. 性别代理2. 学生信息模型3. 对应视图 一、概述 QTableView实现了一个tableview 来显示model 中的元素。这个类用于提供之前由QTable类提供的标准表,但这个是使用Qt的model/view架构提供…

MySQL 存储过程

create procedure 存储过程名 (in | out | INPUT 参数名 参数类型,。。。) 【characteristics 。。。】begin存储过程体end存储过程的参数类型 IN 、OUT、INPUT 都可以在一个存储过程带多个 没有参数(无参数无返回)仅…

边缘网络的作用及管理工具

自从引入软件即服务 (SaaS) 以来,它一直引领着全球按需软件部署创新的竞赛,它提供的灵活性以及其云计算架构带来的易于集成使其成为交付业务应用程序的标准。 在 SaaS 模型中,最佳用户体验的三重奏涉及无缝设置、低延…

JMeter 特殊组件-逻辑控制器与BeanShell PreProcessor 使用示例

文章目录 前言JMeter 特殊组件-逻辑控制器与BeanShell PreProcessor 使用示例1. 逻辑控制器使用1.1. While Controller 使用示例1.2. 如果(If)控制器 使用示例 2. BeanShell PreProcessor 使用示例 前言 如果您觉得有用的话,记得给博主点个赞…

Java课题笔记~ SpringBoot简介

1. 入门案例 问题导入 SpringMVC的HelloWord程序大家还记得吗? SpringBoot是由Pivotal团队提供的全新框架,其设计目的是用来简化Spring应用的初始搭建以及开发过程 原生开发SpringMVC程序过程 1.1 入门案例开发步骤 ①:创建新模块&#…

设计模式-过滤器模式(使用案例)

过滤器模式(Filter Pattern)或标准模式(Criteria Pattern)是一种设计模式,这种模式允许开发人员使用不同的标准来过滤一组对象,通过逻辑运算以解耦的方式把它们连接起来。这种类型的设计模式属于结构型模式…