CNN——LeNet

1.LeNet概述       

         LeNet是Yann LeCun于1988年提出的用于手写体数字识别的网络结构,它是最早发布的卷积神经网络之一,可以说LeNet是深度CNN网络的基石。

        当时,LeNet取得了与支持向量机(support vector machines)性能相媲美的成果,成为监督学习的主流方法。 LeNet当时被广泛用于自动取款机(ATM)机中,帮助识别处理支票的数字。

        下面是整个网络的结构图

        LeNet共有8层,其中包括输入层,3个卷积层,2个子采样层(也就是现在的池化层),1个全连接层和1个高斯连接层。

        上图中用C代表卷积层,用S代表采样层,用F代表全连接层。输入size固定在1*32*32,LeNet图片的输入是二值图像。网络的输出为0~9十个数字的RBF度量,可以理解为输入图像属于0~9数字的可能性大小。

2.详解LeNet

下面对图中每一层做详细的介绍:

  • LeNet使用的卷积核大小都为5*5,步长为1,无填充,只是卷积深度不一样(卷积核个数导致生成的特征图的通道数)
  • 激活函数为Sigmoid
  • 下采样层都是使用最大池化实现,池化的核都为2*2,步长为2,无填充

        input输入层,尺寸为1*32*32的二值图

        C1层是一个卷积层。该层使用6个卷积核,生成特征图尺寸为32-5+1=28,输出为6个大小为28*28的特征图。再经过一个Sigmoid激活函数非线性变换。

        S2层是一个下采样层。生成特征图尺寸为28/2=14,得到6个14*14的特征图。

        C3层是一个卷积层,该层使用16个卷积核,生成特征图尺寸为14-5+1=10,输出为16个10*10的特征图。再经过一个Sigmoid激活函数非线性变换。

        S4层是一个下采样层,生成特征图尺寸为10/2=5,得到16个5*5的特征图

        C5层是一个卷积层,卷积核数量增加至120。生成特征图尺寸为5-5+1=1。得到120个1*1的特征图。这里实际上相当于S4全连接了,但仍将其标为卷积层,原因是如果LeNet-5的输入图片尺寸变大,其他保持不变,那该层特征图的维数也会大于1*1,那就不是全连接了。再经过一个Sigmoid激活函数非线性变换。

        F6层是一个全连接层,该层与C5层全连接,输出84张特征图。再经过一个Sigmoid激活函数非线性变换。

        输出层:输出层由欧式径向基函数(高斯)单元组成,每个类别(0~9数字)对应一个径向基函数单元,每个单元有84个输入。也就是说,每个输出RBF单元计算输入向量和该类别标记向量之间的欧式距离,距离越远,PRF输出越大,同时我们也会将与标记向量欧式距离最近的类别作为数字识别的输出结果。当然现在通常使用的Softmax实现

3.使用LeNet实现Mnist数据集分类

1.导入所需库

import torch
import torch.nn as nn
from torchsummary import summary
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm # 显示训练进度条

2.使用GPU

device = 'cuda' if torch.cuda.is_available() else 'cpu'

3.读取Mnist数据集

# 定义数据转换以进行数据标准化
transform = transforms.Compose([transforms.ToTensor(),  # 将图像转换为 PyTorch 张量
])# 下载并加载 MNIST 训练和测试数据集
train_dataset = datasets.MNIST(root='./dataset', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./dataset', train=False, download=True, transform=transform)# 创建数据加载器以批量加载数据
batch_size = 256
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

4.搭建LeNet

        需要注意的是torch.nn.CrossEntropyLoss自带了softmax函数,所以最后一层使用全连接即可,在训练时使用torch.nn.CrossEntropyLoss

class LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__()self.conv1 = nn.Conv2d(1, 6, kernel_size=5, padding=2) # Mnist尺寸为28*28,这里设置填充变成32*32self.sigmoid = nn.Sigmoid()self.pool = nn.MaxPool2d(kernel_size=2, stride=2)self.conv2 = nn.Conv2d(6, 16, kernel_size=5)self.flatten = nn.Flatten()self.fc1 = nn.Linear(16 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = self.pool(self.sigmoid(self.conv1(x)))x = self.pool(self.sigmoid(self.conv2(x)))x = self.flatten(x)x = self.sigmoid(self.fc1(x))x = self.sigmoid(self.fc2(x))x = self.fc3(x)return x
# 实例化模型
model = LeNet().to(device)
summary(model, (1, 28, 28))

5.训练函数

def train(model, lr, epochs):# 将模型放入GPUmodel = model.to(device)# 使用交叉熵损失函数loss_fn = nn.CrossEntropyLoss().to(device)# SGDoptimizer = torch.optim.SGD(model.parameters(), lr=lr)# 记录训练与验证数据train_losses = []train_accuracies = []# 开始迭代   for epoch in range(epochs):   # 切换训练模式model.train()  # 记录变量train_loss = 0.0correct_train = 0total_train = 0# 读取训练数据并使用 tqdm 显示进度条for i, (inputs, targets) in tqdm(enumerate(train_dataloader), total=len(train_dataloader), desc=f"Epoch {epoch+1}/{epochs}", unit='batch'):# 训练数据移入GPUinputs = inputs.to(device)targets = targets.to(device)# 模型预测outputs = model(inputs)# 计算损失loss = loss_fn(outputs, targets)# 梯度清零optimizer.zero_grad()# 反向传播loss.backward()# 使用优化器优化参数optimizer.step()# 记录损失train_loss += loss.item()# 计算训练正确个数_, predicted = torch.max(outputs, 1)total_train += targets.size(0)correct_train += (predicted == targets).sum().item()# 计算训练正确率并记录train_loss /= len(train_dataloader)train_accuracy = correct_train / total_traintrain_losses.append(train_loss)train_accuracies.append(train_accuracy)# 输出训练信息print(f"Epoch [{epoch + 1}/{epochs}] - Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.4f}")# 绘制损失和正确率曲线plt.figure(figsize=(10, 5))plt.subplot(1, 2, 1)plt.plot(range(epochs), train_losses, label='Training Loss')plt.xlabel('Epoch')plt.ylabel('Loss')plt.legend()plt.subplot(1, 2, 2)plt.plot(range(epochs), train_accuracies, label='Accuracy')plt.xlabel('Epoch')plt.ylabel('Accuracy')plt.legend()plt.tight_layout()plt.show()

6.模型训练

model = LeNet()
lr = 0.9 # sigmoid两端容易饱和,gradient比较小,学得比较慢,所以学习率要大一些
epochs = 20
train(model,lr,epochs)

7.模型测试 

def test(model, test_dataloader, device, model_path):# 将模型设置为评估模式model.eval()# 将模型移动到指定设备上model.to(device)# 从给定路径加载模型的状态字典model.load_state_dict(torch.load(model_path))correct_test = 0total_test = 0# 不计算梯度with torch.no_grad():# 遍历测试数据加载器for inputs, targets in test_dataloader:  # 将输入数据和标签移动到指定设备上inputs = inputs.to(device)targets = targets.to(device)# 模型进行推理outputs = model(inputs)# 获取预测结果中的最大值_, predicted = torch.max(outputs, 1)total_test += targets.size(0)# 统计预测正确的数量correct_test += (predicted == targets).sum().item()# 计算并打印测试数据的准确率test_accuracy = correct_test / total_testprint(f"Accuracy on Test: {test_accuracy:.4f}")return test_accuracy
model_path = save_path
test(model, test_dataloader, device, save_path)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/593233.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MySQL5.7控制复制源服务器的SQL语句

官网地址:MySQL :: MySQL 5.7 Reference Manual :: 13.4.1 SQL Statements for Controlling Replication Source Servers 欢迎关注留言,我是收集整理小能手,工具翻译,仅供参考,笔芯笔芯. MySQL 5.7 参考手册 / ... …

Vue:Vue 3.4 新特性

Vue 3.4 是 Vue.js 的一个重要更新,它在性能、API 和能上带来了许多改进和增强。 更好的 TypeScript 支持 Vue 3.4 进一步加强了 TypeScript 的集成,提供了更好的类型推断和更丰富的类型定义,使得使用 TypeScript 开发 Vue 应用变得更加顺畅…

看懂 Git Graph

目录 文章目录 目录Git Graph看懂 GraphVSCode Git Graph 插件1. 选择展示的 Branches2. Checkout 到一个 Branch3. 找到指定 Branch 最新的 Commit4. 找到 Branch 分叉口5. 查看 2 个 Commits 之前的区别 Git Graph Git Graph 是服务于 Git 分支管理的一种可视化工具&#xf…

大文件断点下载Range下载zip包显示文件损坏

问题:大文件下载,其它格式的文件及rar格式的压缩包正常下载但是 之后zip包下载后解压失败 原因分析: 1. 查看上传文件的属性值 如图,10.4kb是已经约去小数点的值,准确的大小应该是10663字节10.4130859375KB,所以用10.…

VS Code 远程连接云机器训练配置

VS Code 远程连接云机器 Visual Studio Code(以下简称 VS Code)是一个由微软开发的代码编辑器。VS Code 支持代码补全、代码片段、代码重构、Git 版本控制等功能。 安装 VSCode步骤简单且网上有很多教程,这里不过多重复了。 VS Code 现已支…

Nacos配置回滚

前言 很多时候,我们会配置错一些属性,或者需要回滚某些属性,这时候使用Nacos的回滚功能就很方便了 配置回滚 1、在控制台中,选择左侧导航栏的 “配置管理”,进入历史版本,选择Group和data id&#xff0c…

掌握ROS:完整的认识ROS

目录 前景知识补充 什么是元操作系统? 高层决策 高层决策的特点 什么是ROS? ROS的框架 1. 通信机制 系统工具 开发框架 ROS的开发流程 ROS的应用场景 前景知识补充 什么是元操作系统? 元操作系统(Meta-Operating Sys…

opencv C++透视变换

#include <opencv2/opencv.hpp> #include <iostream> int main() {// 1. 读取图像 cv::Mat image cv::imread("C:/Users/10623/Pictures/adf4d0d56444414cbeb809f0933b9214.png");if (image.empty()) {std::cerr << "无法读取图像!"…

PTA找出不是两个数组共有的元素(C语言)

读题&#xff1a; 给定两个整型数组&#xff0c;本题要求找出不是两者共有的元素。 输入格式: 输入分别在两行中给出两个整型数组&#xff0c;每行先给出正整数N&#xff08;≤20&#xff09;&#xff0c;随后是N个整数&#xff0c;其间以空格分隔。 输出格式: 在一行中按…

以源码为驱动:Java版工程项目管理系统平台助力工程企业迈向数字化管理的巅峰

随着企业规模的不断扩大和业务的快速发展&#xff0c;传统的工程项目管理方式已经无法满足现代企业的需求。为了提高工程管理效率、减轻劳动强度、提高信息处理速度和准确性&#xff0c;企业需要借助先进的数字化技术进行转型。本文将介绍一款采用Spring CloudSpring BootMybat…

vr眼镜和AR眼镜的区别有哪些?哪些产品可以支持VR应用?

vr眼镜怎么连接手机 要将VR眼镜连接到手机上&#xff0c;您可以按照以下步骤进行&#xff1a; 1. 确保您的手机支持VR应用程序&#xff1a;首先&#xff0c;确保您的手机具备运行VR应用程序的硬件和软件条件。一些VR应用程序可能对设备有特定的要求&#xff0c;如处理器性能、操…

【Java并发】深入浅出 synchronized关键词原理-上

一个问题的思考 建设我们有两个线程&#xff0c;一个进行5000次的相加操作&#xff0c;另一个进行5000次的减操作。那么最终结果是多少 package com.jia.syn;import java.util.concurrent.TimeUnit;/*** author qxlx* date 2024/1/2 10:08 PM*/ public class SynTest {privat…

Sharding Sphere 教程 简介

一 文档简介 1.1 分库分表诞生的前景 随着系统用户运行时间还有用户数量越来越多&#xff0c;整个数据库某些表的体积急剧上升&#xff0c;导致CRUD的时候性能严重下降&#xff0c;还容易造成系统假死。 这时候系统都会做一些基本的优化&#xff0c;比如加索引…

高德地图经纬度坐标导出工具

https://tool.xuexiareas.com/map/amap 可以导出单个点&#xff0c;也可以导出多个&#xff0c;多个点可以连成线&#xff0c;可用于前端开发时自己模拟“线“数据

基于Springboot的服务端开发脚手架-自动生成工具

继之前的 专题系列课程&#xff1a; ​​从零开始搭建grpc分布式应用​​完整DEMO&#xff1a;​​基于Springboot的Rpc服务端开发脚手架(base-grpc-framework)​​ 后带来一款项目自动手成工具&#xff08;由于包路径等原因&#xff0c;完整demo想应用在实际开发中需要改很多代…

编程笔记 html5cssjs 008 HTML图片

编程笔记 html5&css&js 008 HTML图片 一、HTML 图像二、HTML 图像- Alt属性三、HTML 图像- 设置图像的高度与宽度四、支持的图像格式五、对齐设置六、操作小结 和印刷品一样&#xff0c;网页上不仅要有文字&#xff0c;还要有图片和声音、视频、绘画等媒体形式。这一节…

广州求职招聘(找工作)去哪里找比较好

在广州找工作&#xff0c;可以选择“吉鹿力招聘网”这个平台。它是一个号称直接和boss聊的互联网招聘神器&#xff0c;同时&#xff0c;“吉鹿力招聘网”作岗位比较齐全&#xff0c;企业用户也多&#xff0c;比较全面。在“吉鹿力招聘网”历即可投递岗位。 广州找工作上 吉鹿力…

学习的记录

一、内核安装 1.安装内核编译工具 install gcc gcc-c ncurses ncurses-devel cmake elfutils-libelf-devel openssl-devel 将内核源码linux-4.14.160.tar拷贝到/usr/src/kernels目录下 cp -r linux-4.20.2.tar.xz /usr/src/kernels 2.进入内核源码所在的文件夹&#xff1…

结构体与函数简单总结(依靠洛谷结构体题与函数题单)

函数结构体简单总结 依靠洛谷函数与字符串题单 文章目录 函数结构体简单总结前言一、函数1、有返回值的函数2、无返回值函数3、递归函数 二、结构体总结 前言 之前总结了字符串的简单应用&#xff0c;随着函数与结构体的题单完成&#xff0c;入门题单也就刷完了&#xff0c;现…

你可能不知道的5款好用封面设计工具,快来一探究竟吧!

我相信每个作者和出版商都希望在一部作品完成后有一个醒目的封面&#xff0c;这样潜在的读者就会有足够的好奇心拿起这本书&#xff0c;你的书的销量就会上升。这就是封面设计软件的使用&#xff0c;专业的封面设计软件可以增加前沿效果&#xff0c;呈现最适合书籍内容的创意布…