经典卷积神经网络 - LeNet

image-20231022164234916

该模型用于手写的数字识别。
LeNet模型包含了多个卷积层和池化层,以及最后的全连接层用于分类。其中,每个卷积层都包含了一个卷积操作和一个非线性激活函数,用于提取输入图像的特征。池化层则用于缩小特征图的尺寸,减少模型参数和计算量。全连接层则将特征向量映射到类别概率上。

MNISt数据集

50000个训练数据,10000个测试数据。图像大小为28x28,共10类(0~9)。

  • LeNet是早期成功的神经网络
  • 先使用卷积层来学习图片空间信息
  • 然后使用全连接层来转换到类别空间

对于padding

通用的卷积时padding 的选择

如卷积核宽高为3时 padding 选择1

如卷积核宽高为5时 padding 选择2

如卷积核宽高为7时padding选择3

至于选择填充多少像素,通常有两个选择,分别叫做Valid卷积和Same卷积。

Valid卷积意味着不填充,这样的话,如果你有一个 n × n n\times n n×n的图像,用一个 f × f f\times f f×f的过滤器卷积,它将会给你一个 ( n − f + 1 ) × ( n − f + 1 ) (n-f+1)\times (n-f+1) (nf+1)×(nf+1)维的输出。例如,有一个6×6的图像,通过一个3×3的过滤器,得到一个4×4的输出。

Same卷积意味你填充后,你的输出大小和输入大小是一样的。根据这个公式 n − f + 1 n-f+1 nf+1,当你填充 p p p个像素点,n就变成了 n + 2 p n+2p n+2p,最后公式变为 n + 2 p − f + 1 n+2p-f+1 n+2pf+1。因此如果你有一个 n × n n\times n n×n的图像,用 p p p个像素填充边缘,输出的大小就是这样的 ( n + 2 p − f + 1 ) × ( n + 2 p − f + 1 ) (n+2p−f+1)\times (n+2p−f+1) (n+2pf+1)×(n+2pf+1)。如果你想让 ( n + 2 p − f + 1 ) = n (n+2p−f+1)=n (n+2pf+1)=n的话,使得输出和输入大小相等,如果你用这个等式求解 p p p,那么 p = ( f − 1 ) / 2 p=(f-1)/2 p=(f1)/2。所以当 f f f是一个奇数的时候,只要选择相应的填充尺寸,你就能确保得到和输入相同尺寸的输出。

代码实现

LeNet(LeNet-5)由两个部分组成:卷积编码器和全连接层密集块。

model.py

from torch import nnclass Reshape(nn.Module):def forward(self,x):return x.reshape((-1,1,28,28))class MyLeNet(nn.Module):def __init__(self, *args, **kwargs) -> None:super().__init__(*args, **kwargs)# 假如sigmoid激活函数后 损失不下降self.model = nn.Sequential(Reshape(),nn.Conv2d(1,6,kernel_size=5,padding=2),nn.Sigmoid(),nn.AvgPool2d(2,stride=2),nn.Conv2d(6,16,kernel_size=5),nn.Sigmoid(),nn.AvgPool2d(2,stride=2),nn.Flatten(),nn.Linear(16*5*5,120),nn.Linear(120,84),nn.Linear(84,10))def forward(self,x):return self.model(x)

train.py

# 扫描数据次数
epochs = 20
# 分组大小
batch = 64
# 学习率
learning_rate = 0.05
# 训练次数
train_step = 0
# 测试次数
test_step = 0# 定义图像转换
transform = transforms.Compose([transforms.ToTensor()
])
# 读取数据
train_dataset = datasets.MNIST(root="./dataset",train=True,transform=transform,download=True)
test_dataset = datasets.MNIST(root="./dataset",train=False,transform=transform,download=True)
# 加载数据
train_dataloader = DataLoader(train_dataset,batch_size=batch,shuffle=True,num_workers=0)
test_dataloader = DataLoader(test_dataset,batch_size=batch,shuffle=True,num_workers=0)
# 数据大小
train_size = len(train_dataset)
test_size = len(test_dataset)
print("训练集大小:{}".format(train_size))
print("验证集大小:{}".format(test_size))# GPU
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")# 创建网络
net = MyLeNet()
net = net.to(device)
# 定义损失函数
loss = nn.CrossEntropyLoss()
loss = loss.to(device)
# 定义优化器
optimizer = torch.optim.SGD(net.parameters(),lr=learning_rate)writer = SummaryWriter("logs")
# 训练
for epoch in range(epochs):print("-------------------第 {} 轮训练开始-------------------".format(epoch))net.train()for data in train_dataloader:train_step = train_step + 1images,targets = dataimages = images.to(device)targets = targets.to(device)outputs = net(images)loss_out = loss(outputs,targets)optimizer.zero_grad()loss_out.backward()optimizer.step()if train_step%100==0:writer.add_scalar("Train Loss",scalar_value=loss_out.item(),global_step=train_step)print("训练次数:{},Loss:{}".format(train_step,loss_out.item()))# 测试net.eval()total_loss = 0total_accuracy = 0with torch.no_grad():for data in test_dataloader:test_step = test_step + 1images, targets = dataimages = images.to(device)targets = targets.to(device)outputs = net(images)loss_out = loss(outputs, targets)total_loss = total_loss + loss_outaccuracy = (targets == torch.argmax(outputs,dim=1)).sum()total_accuracy = total_accuracy + accuracy# 计算精确率print(total_accuracy)accuracy_rate = total_accuracy / test_sizeprint("第 {} 轮,验证集总损失为:{}".format(epoch+1,total_loss))print("第 {} 轮,精确率为:{}".format(epoch+1,accuracy_rate))writer.add_scalar("Test Total Loss",scalar_value=total_loss,global_step=epoch+1)writer.add_scalar("Accuracy Rate",scalar_value=accuracy_rate,global_step=epoch+1)torch.save(net,"./model/net_{}.pth".format(epoch+1))print("模型net_{}.pth已保存".format(epoch+1))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/115524.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

免费领取!TikTok Shop “全托管”黑五大促官方备战指南来啦!

黑五网一大促即将来袭,自“全托管”模式上线以来,TikTok for Business在沙特阿拉伯和英国市场开展了古尔邦节大促、夏季大促、返校季大促等活动,今年更是会借着黑五网一大促之际,首次覆盖美国市场,为全托管商家带来全球…

如何在Potplayer中使用公网访问群晖WebDav?

文章目录 1 使用环境要求:2 配置webdav3 测试局域网使用potplayer访问webdav4 内网穿透,映射至公网5 使用固定地址在potplayer访问webdav ​ 国内流媒体平台的内容让人一言难尽,就算是购买了国外的优秀作品,也总是在关键剧情上删删…

系统性认知网络安全

前言:本文旨在介绍网络安全相关基础知识体系和框架 目录 一.信息安全概述 信息安全研究内容及关系 信息安全的基本要求 保密性Confidentiality: 完整性Integrity: 可用性Availability: 二.信息安全的发展 20世纪60年代&…

学生成绩管理神器

老师们是否还在为繁琐的成绩查询而烦恼?是否希望有一个简便易用的成绩查询系统,让同学们可以自助查询成绩?那么,这篇文章就是你的救星! 让我们一起来认识一下这个成绩查询系统。它是一个基于网页或微信小程序的应用程序…

数据结构实验1约瑟夫环

刚开始m值为20 循环链表 #include<iostream>using namespace std;typedef struct LNode {int data; int num;struct LNode *next;}LNode ,*LinkList; int m 20; int n;void Init(LinkList &L) {cin>>n;LinkList p L;p ->data n;for(int i 0 ; i < n…

Spring底层原理(一)

Spring底层原理&#xff08;一&#xff09; ApplitionContext与BeanFactory BeanFactory是ApplicationContext的父接口BeanFactory才是Spring的核心容器,ApplicationContext对其功能进行了组合 类图 内部方法调用 BeanFactory的功能 获取bean检查是否包含bean获取bean别名 …

220V降压5V用什么方案比较好?

对于将220V降压到5V的方案&#xff0c;根据输出电流大小&#xff0c;有两种选择&#xff1a;AH8652和AH8699B。 AH8652是一个sot23-3封装的芯片&#xff0c;固定输出5V&#xff0c;峰值电流为200mA&#xff0c;并内置了MOS管。这个芯片适合需要固定输出电压的应用场景&#xf…

设计模式:模板模式(C#、JAVA、JavaScript、C++、Python、Go、PHP)

简介&#xff1a; 模板模式&#xff0c;它是一种行为型设计模式&#xff0c;它定义了一个操作中的算法的框架&#xff0c;将一些步骤延迟到子类中实现&#xff0c;使得子类可以不改变一个算法的结构即可重定义该算法的某些特定步骤。 通俗地说&#xff0c;模板模式就是将某一行…

✔ ★【备战实习(面经+项目+算法)】 10.22学习时间表(算法刷题:4道)

✔ ★【备战实习&#xff08;面经项目算法&#xff09;】 坚持完成每天必做如何找到好工作1. 科学的学习方法&#xff08;专注&#xff01;效率&#xff01;记忆&#xff01;心流&#xff01;&#xff09;2. 每天认真完成必做项&#xff0c;踏实学习技术 认真完成每天必做&…

简单易用的操作界面,让你轻松制作电子期刊

随着互联网的发展&#xff0c;电子期刊已经成为了越来越多人的选择。FLBOOK在线制作电子杂志平台作为一款简单易用的操作界面&#xff0c;为用户提供了制作电子期刊的便利。 但是你知道如何使用FLBOOK在线制作电子杂志平台制作一本电子期刊吗&#xff1f; 1.点击开始创作&#…

Spring boot 集成 xxl-job

文章目录 xxl-job 简介引入xxl-job依赖配置xxl-job config添加properties文件配置BEAN模式&#xff08;方法形式&#xff09;步骤一&#xff1a;执行器项目中&#xff0c;开发Job方法&#xff1a;步骤二&#xff1a;调度中心&#xff0c;新建调度任务 xxl-job 简介 官网:https:…

k8s----11、service

services 1、概述2、存在的意义2.1 服务发现2.2 负载均衡 3、pod与service的关系4、service 三种类型4.1 、 ClusterIP4.2 、NodePort4.3 、LoadBalancer 1、概述 Service 是 Kubernetes 最核心概念&#xff0c;通过创建 Service,可以为一组具有相同功能的容器应 用提供一个统…

YOLOV8目标检测——最全最完整模型训练过程记录

文章目录 前言1 下载yolov8&#xff08;[网址](https://github.com/ultralytics/ultralytics)&#xff09;2 配置conda环境3 用pycharm打开文件3 训练自己的YOLOV8数据集4 run下运行完了之后没有best.pt文件5 导出为onnx文件6 yolov8应用完整案例&#xff08;免费且包含源代码、…

iOS上架App Store的全攻略

​ 第一步&#xff1a;申请开发者账号 在开始将应用上架到App Store之前&#xff0c;你需要申请一个开发者账号。 1.1 打开苹果开发者中心网站&#xff1a;Apple Developer 1.2 使用Apple ID和密码登录&#xff08;如果没有账号则需要注册&#xff09;&#xff0c;要确保使用…

Oracle的dbms.rls实现数据访问控制

在大部份系统中&#xff0c;权限控制主要定义为模块进入权限的控制和数据列访问权限的控制(如&#xff1a;某某人可以进入某个控制&#xff0c;仓库不充许查看有关部门的字段等等)。 但在某些系统中&#xff0c;权限控制又必须定义到数据行访问权限的控制&#xff0c;此需求一般…

1920: 【C2】【字符串】调换位置

目录 题目描述 输入 输出 样例输入 样例输出 题目描述 将用逗号隔开的两个英语单词交换位置输出 输入 一行以逗号隔开的两个单词 输出 将两个单词交换后输出 样例输入 abc,de 样例输出 de,abc 4 33 104 87 16C: #include<bits/stdc.h> using namespace st…

【Mysql】Mysql的数据目录

前言 我们知道InnoDB 、 MyISAM 存储引擎都是把表存储在文件系统上的。当我们想读取数据的时候&#xff0c;这些存储引擎会从文件系统中把数据读出来返回给我们&#xff0c;当我们想写入数据的时候&#xff0c;这些存储引擎会把这些数据又写回文件系统。那么这两个存储引擎的数…

基于Python实现的快速的仿手写文字的图片生成器项目源码

Quick Hand &#x1f4dd; 介绍 快速的仿手写文字的图片生成器。 完整代码下载地址&#xff1a;基于Python实现的快速的仿手写文字的图片生成器 界面预览&#xff1a; &#x1f52e; 使用说明 原理&#xff1a;首先&#xff0c;在水平位置、竖直位置和字体大小三个自由度上…

uniapp开发微信小程序,webview内嵌h5,h5打开pdf地址,解决方案

根据公司要求&#xff0c;让我写一个h5&#xff0c;后续会嵌入到合作公司的微信小程序的webview中&#xff0c;如果是自己公司微信小程序&#xff0c;可以采取先下载下来pdf&#xff0c;然后通过wx.openDocument&#xff0c;进行单纯的预览操作&#xff0c;这个可以根据这个老哥…

C语言里的static变量其他语言是看不上还是学不去?

C语言里的static变量其他语言是看不上还是学不去? static变量在C语言中被用于具有静态存储期的局部变量或全局变量。它有以下几个特点&#xff1a; 1. 静态存储期&#xff1a;static变量在程序执行时分配内存&#xff0c;直到程序结束才会释放&#xff0c;其生命周期与程序的…