Pytorch实战(一):LeNet神经网络

文章目录

  • 一、模型实现
    • 1.1数据集的下载
    • 1.2加载数据集
    • 1.3模型训练
    • 1.4模型预测


  LeNet神经网络是第一个卷积神经网络(CNN),首次采用了卷积层、池化层这两个全新的神经网络组件,接收灰度图像,并输出其中包含的手写数字,在手写字符识别任务上取得了瞩目的准确率。LeNet网络的一系列的版本,以LeNet-5版本最为著名,也是LeNet系列中效果最佳的版本。LeNet神经网络输入图像大小必须为32x32,且所用卷积核大小固定为5x5,模型结构如下:
在这里插入图片描述

模型参数:

  • INPUT(输入层):输入图像尺寸为32x32,且是单通道灰色图像。
  • C1(卷积层):使用6个5x5大小的卷积核,步长为1,卷积后得到6张28×28的特征图。
  • S2(池化层):使用了6个2×2 的平均池化,池化后得到6张14×14的特征图。
  • C3(卷积层):使用了16个大小为5×5的卷积核,步长为1,得到 16 张10×10的特征图。
  • S4(池化层):使用16个2×2的平均池化,池化后得到16张5×5 的特征图。
  • C5(卷积层):使用120个大小为5×5的卷积核,步长为1,卷积后得到120张1×1的特征图。
  • F6(全连接层):输入维度120,输出维度是84(对应7x12 的比特图)。
  • OUTPUT(输出层):使用高斯核函数,输入维度84,输出维度是10(对应数字 0 到 9)。

该模型有如下特点:

  • 1.首次提出卷积神经网络基本框架: 卷积层,池化层,全连接层。
  • 2.卷积层的权重共享,相较于全连接层使用更少参数,节省了计算量与内存空间。
  • 3.卷积层的局部连接,保证图像的空间相关性。
  • 4.使用映射到空间均值下采样,减少特征数量。
  • 5.使用双曲线(tanh)或S型(sigmoid)形式的非线性激活函数。

一、模型实现

1.1数据集的下载

  使用torchversion内置的MNIST数据集,训练集大小60000,测试集大小10000,图像大小是1×28×28,包括数字0~9共10个类。

from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
import torchvision
# 下载训练、测试数据集
mnist_train = torchvision.datasets.MNIST(root='./dataset/',train=True, download=True, transform=transforms.ToTensor())
mnist_test = torchvision.datasets.MNIST(root='./dataset/',train=False, download=True, transform=transforms.ToTensor())
print('mnist_train基本信息为:',mnist_train)
print('-----------------------------------------')
print('mnist_test基本信息为:',mnist_test)
print('-----------------------------------------')
img,label=mnist_train[0]
print('mnist_train[0]图像大小及标签为:',img.shape,label)

在这里插入图片描述

1.2加载数据集

trainDataLoader = DataLoader(mnist_train, batch_size=64, num_workers=5, shuffle=True)
testDataLoader = DataLoader(mnist_test, batch_size=64, num_workers=0, shuffle=True)
write = SummaryWriter('./log')
step = 0
for images, labels in testDataLoader:write.add_images(tag='train', images, global_step=step)step += 1
write.close()

  注意不能使用for images, labels in testDataLoader.datasettestDataLoader.dataset[0]是保存图像(28
,28)和对应标签的元组,而Tensorboardadd_images只能输入NCHW格式对象,使用该代码会报错:

size of input tensor and input format are different. tensor shape: (1, 28, 28), input_format: NCHW

数据加载器按batch_size对数据及标签进行封装名,可直接作为输入。查看封装的元组:

for data in testDataLoader:print('type(data):',type(data))img,label=dataprint('type(img):',type(img),'img.shape:',img.shape)print('type(label):',type(label),'label.shape:',label.shape)

在这里插入图片描述

1.3模型训练

  LeNet模型的输入为(32,32)的图片,而MNIST数据集为(28,28)的图片,故需对原图片进行填充。搭建模型:

class LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__()self.model = nn.Sequential(  #MNIST数据集图像大小为28x28,而LeNet输入为32x32,故需填充nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2),  #C1层共六个卷积核,故out_channels=6nn.AvgPool2d(kernel_size=2, stride=2),  #C2层使用平均池化nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),nn.AvgPool2d(kernel_size=2, stride=2),nn.Flatten(),nn.Conv2d(in_channels=16 * 5 * 5, out_channels=120),nn.Linear(in_features=120, out_features=84),nn.Linear(in_features=84, out_features=10))def forward(self, x):return self.model(x)# 初始化模型对象
myLeNet = LeNet()

  设置损失函数、优化器并训练模型:

# 设置损失函数为交叉熵损失函数
loss_fn = nn.CrossEntropyLoss()
loss_fn = loss_fn.to(device)
# 设置优化器,使用Adam优化算法
learning_rate = 1e-2
optimizer = torch.optim.Adam(myLeNet.parameters(), lr=learning_rate)
total_train_step = 0  # 总训练次数
epoch = 10  # 训练轮数
writer = SummaryWriter(log_dir='./runs/LeNet/')
for i in range(epoch):print("-----第{}轮训练开始-----".format(i + 1))myLeNet.train()  # 训练模式train_loss = 0for data in trainDataLoader:imgs, labels = dataimgs = imgs.to(device)  # 适配GPU/CPUlabels = labels.to(device)outputs = myLeNet(imgs)loss = loss_fn(outputs, labels)#计算损失函数optimizer.zero_grad()  # 清空之前梯度loss.backward()  # 反向传播optimizer.step()  # 更新参数total_train_step += 1  # 更新步数train_loss += loss.item()writer.add_scalar("train_loss_detail", loss.item(), total_train_step)writer.add_scalar("train_loss_total", train_loss, i + 1)writer.close()

1.4模型预测

myLeNet.eval() 
total_test_loss = 0  # 当前轮次模型测试所得损失
total_accuracy = 0  # 当前轮次精确率
with torch.no_grad():  # 关闭梯度反向传播for data in testDataLoader:imgs, targets = dataimgs = imgs.to(device)targets = targets.to(device)outputs = myLeNet(imgs)loss = loss_fn(outputs, targets)total_test_loss = total_test_loss + loss.item()accuracy = (outputs.argmax(1) == targets).sum()total_accuracy = total_accuracy + accuracy
writer.add_scalar("test_loss", total_test_loss, i+1)
writer.add_scalar("test_accuracy", total_accuracy/len(mnist_test), i+1)

https://blog.csdn.net/qq_43307074/article/details/126022041?ops_request_misc=%257B%2522request%255Fid%2522%253A%2522171938503416800186515588%2522%252C%2522scm%2522%253A%252220140713.130102334…%2522%257D&request_id=171938503416800186515588&biz_id=0&utm_medium=distribute.pc_search_result.none-task-blog-2alltop_click~default-2-126022041-null-null.142v100pc_search_result_base3&utm_term=LeNet&spm=1018.2226.3001.4187

https://blog.csdn.net/hellocsz/article/details/80764804?ops_request_misc=&request_id=&biz_id=102&utm_term=LeNet&utm_medium=distribute.pc_search_result.none-task-blog-2allsobaiduweb~default-1-80764804.142v100pc_search_result_base3&spm=1018.2226.3001.4187

https://blog.csdn.net/qq_45034708/article/details/128319241?ops_request_misc=%257B%2522request%255Fid%2522%253A%2522171936257316800222847105%2522%252C%2522scm%2522%253A%252220140713.130102334…%2522%257D&request_id=171936257316800222847105&biz_id=0&utm_medium=distribute.pc_search_result.none-task-blog-2alltop_positive~default-1-128319241-null-null.142v100pc_search_result_base3&utm_term=LeNet&spm=1018.2226.3001.4187

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/35039.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

告别模糊时代,扫描全能王带来清晰世界

模糊碑文引发的思考 上个月中旬去洛阳拜访了著名的龙门石窟,本就对碑文和文字图画感兴趣的我们,准备好好欣赏一下龙门石窟的历史文化古迹。到了地方之后,我发现石窟的高度和宽度远远超出了想象,正因如此,拍出来的文字…

NewspaceGPT带你玩系列之美人鱼图表

这里写目录标题 注册一个账号,用qq邮箱,然后登录选一个可用的Plus,不要选3.5探索GPT今天的主角是开始寻梦美人鱼图表我选第一个试一下问:重新回答上面的问题,一切都用汉语重新生成一个流程图:生成一个网站登…

OpenAI“跌倒”,国产大模型“吃饱”?

大数据产业创新服务媒体 ——聚焦数据 改变商业 在AI的世界里,OpenAI就像是一位高高在上的霸主,它的一举一动,都能引发行业里的地震。然而,就在不久前,这位霸主突然宣布了一个决定,自7月9日起,…

2024热门骨传导蓝牙耳机怎么选?超全的选购攻略附带好物推荐!

对于很多喜欢运动健身的小伙伴,在现在市面上这么多种类耳机的选择上,对于我来说的话还是很推荐大家去选择骨传导运动耳机的,相较于普通的入耳式蓝牙耳机,骨传导耳机是通过振动来传输声音的,而入耳式耳机则是通过空气传…

以Bert训练为例,测试torch不同的运行方式,并用torch.profile+HolisticTraceAnalysis分析性能瓶颈

以Bert训练为例,测试torch不同的运行方式,并用torch.profileHolisticTraceAnalysis分析性能瓶颈 1.参考链接:2.性能对比3.相关依赖或命令4.测试代码5.HolisticTraceAnalysis代码6.可视化A.优化前B.优化后 以Bert训练为例,测试torch不同的运行方式,并用torch.profileHolisticTra…

正则表达式阅读理解

这段正则表达式可以匹配什么呢? ((max|min)\\s*\\([^\\)]*(,[^\\)]*)*\\)|[a-zA-Z][a-zA-Z0-9]*(_[a-zA-Z][a-zA-Z0-9]*)?(\\*||%)?|[0-9](\\.[0-9])?|\\([^\\)]*(,[^\\)]*)*\\))(\\s*[-*/%]\\s*([a-zA-Z][a-zA-Z0-9]*(_[a-zA-Z][a-zA-Z0-9]*)?(\\*||%)?|[0-…

Charls数据库+预测模型发二区top | CHARLS等七大老年公共数据库周报(6.19)

七大老年公共数据库 七大老年公共数据库共涵盖33个国家的数据,包括:美国健康与退休研究 (Health and Retirement Study, HRS);英国老龄化纵向研究 (English Longitudinal Study of Ageing, ELSA);欧洲健康、…

HashMap第5讲——resize方法扩容源码分析及细节

put方法的源码和相关的细节已经介绍完了,下面我们进入扩容功能的讲解。 一、为什么需要扩容 这个也比较好理解。假设现在HashMap里的元素已经很多了,但是链化比较严重,即便树化了,查询效率也是O(logN),肯定没有O(1)好…

IDEA注释快只有一行时不分行的设置

在编写注释时,有时使用注释块来标注一个变量或者一段代码时,为了节约空间,希望只在一行中显示注释快。只需要按照下图将“一行注释不分行”勾选上即可。

M Farm RPG Assets Pack(农场RPG资源包)

🌟塞尔达的开场动画:风鱼之歌风格!🌟 像素参考:20*20 字体和声音不包括在内 资产包括: 1名身体部位分离的玩家和4个方向动画: 闲逛|散步|跑步|持有物品|使用工具|拉起|浇水 6个带有4个方向动画的工具 斧头|镐|喙|锄头|水壶|篮子 4个NPC,有4个方向动画: 闲逛|散步 �…

LSH算法:高效相似性搜索的原理与Python实现II

局部敏感哈希(LSH)是一种高效的近似相似性搜索技术,广泛应用于需要处理大规模数据集的场景。在当今数据驱动的世界中,高效的相似性搜索算法对于维持业务运营至关重要,它们是许多顶尖公司技术堆栈的核心。 相似性搜索面…

去掉window11设备和驱动器中的百度网盘图标

背景 window系统设备驱动器中显示百度网盘图标,个人强迫症,要去掉!!! 去掉window11->设备和驱动器->百度网盘 的图标 登录百度网盘点击”同步“ 点击设置 在基本设置里面去掉勾选“在我的电脑中显示百度网盘…

麒麟桌面操作系统上使用命令行添加软件图标到任务栏

原文链接:麒麟桌面操作系统上使用命令行添加软件图标到任务栏 Hello,大家好啊!今天给大家带来一篇在麒麟桌面操作系统上使用命令行添加软件图标到任务栏的文章。通过命令行添加软件图标到任务栏,可以快速、便捷地将常用的软件固定…

当大模型开始「考上」一本

参加 2024 河南高考,豆包和文心 4.0 过了一本线,但比 GPT-4o 还差点。 今天的大模型,智力水平到底如何? 2024 年高考陆续出分,我们想要解开这个过去一年普罗大众一直争论不休的话题。高考是衡量人类智力和学识水平的…

聚力教研共成长!思腾合力携手昇腾AI打造人工智能云平台

高校作为科研和创新的前沿阵地,不断推动科学技术的发展与进步。多元化的学科背景和丰富的科研课题使高校在科研创新中具有独特的竞争力,能够引领科技的发展和进步。人工智能技术快速迭代,高校在人才培养上往往偏重于理论知识的传授&#xff0…

如何获取阿里云盘的 token

方法一、通过 alist 便携获取 Token 一、访问:阿里云盘/分享 | AList文档 二、找到 刷新令牌 ,点击 获取Token,并通过阿里云APP扫码登录后获取,取到之后将 Token 粘贴至软件内 方法二、通过 网页登录 自行获取 token 我这里用的…

Sora:探索AI视频模型的无限可能

随着人工智能技术的飞速发展,AI在视频处理和生成领域的应用正变得越来越广泛。Sora,作为新一代AI视频模型,展示了前所未有的潜力和创新能力。本文将深入探讨Sora的功能、应用场景以及它所带来的革命性变化。 一、Sora的核心功能 1.1 视频生…

Pandas中的数据转换[细节]

今天我们看一下Pandas中的数据转换,话不多说直接开始🎇 目录 一、⭐️apply函数应用 apply是一个自由度很高的函数 对于Series,它可以迭代每一列的值操作: 二、⭐️矢量化字符串 为什么要用str属性 替换和分割 提取子串 …

three.js基础环境搭建

three.js three.js介绍安装threejs文件资源目录介绍本地静态服务器vscode配置live-server插件nodejs配置本地静态服务器项目的开发环境引入threejs 基础知识右手坐标系程序结构 three.js介绍 three.js官网 Three.js是一款基于WebGL的JavaScript 3D库,它使得开发者能…

go语言day2 配置

使用cmd 中的 go install ; go build 命令出现 go cannot find main module 错误怎么解决? go学习-问题记录(开发环境)go: cannot find main module; see ‘go help modules‘_go: no flags specified (see go help mod edit)-CSDN博客 在本…