【深度学习】LeNet网络架构

文章目录

  • 什么是LeNet
  • 代码实现网络架构


什么是LeNet

LeNet是一种经典的卷积神经网络,由Yann LeCun等人在1998年提出。它是深度学习中第一个成功应用于手写数字识别的卷积神经网络,并且被认为是现代卷积神经网络的基础。

LeNet模型包含了多个卷积层和池化层,以及最后的全连接层用于分类。其中,每个卷积层都包含了一个卷积操作和一个非线性激活函数,用于提取输入图像的特征。池化层则用于缩小特征图的尺寸,减少模型参数和计算量。全连接层则将特征向量映射到类别概率上。

在这里插入图片描述


代码实现网络架构

如何搭建网络模型参考博客:Pytorch学习笔记(模型训练)
在这里插入图片描述我们采用CIFAR-10数据集进行训练测试,上面网络模型是1个channel的32x32,而我们的数据集是3个channel的32x32,模型结构不变,改变一下输入输出大小。
model.py:

import torch
from torch import nn# 搭建网络模型
class LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__()self.model = nn.Sequential(nn.Conv2d(3, 16, kernel_size=5, stride=1, padding=0),nn.ReLU(),nn.MaxPool2d(kernel_size=2),nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=0),nn.ReLU(),nn.MaxPool2d(kernel_size=2),nn.Flatten(),nn.Linear(32 * 5 * 5, 120),nn.ReLU(),nn.Linear(120, 84),nn.ReLU(),nn.Linear(84, 10),)def forward(self, x):x = self.model(x)return x# 测试
if __name__ == '__main__':leNet = LeNet()input = torch.ones((64, 3, 32, 32))output = leNet(input)print(output.shape)

train.py

import torch.optim
import torchvision.datasets
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterfrom learning.lenet.model import LeNet# 1. 数据集
dataset_train = torchvision.datasets.CIFAR10("./data", train=True, transform=torchvision.transforms.ToTensor(),download=True)
dataset_test = torchvision.datasets.CIFAR10("./data", train=True, transform=torchvision.transforms.ToTensor(),download=True)
train_data_size = len(dataset_train)
test_data_size = len(dataset_test)
# 2. 加载数据集
dataloader_train = DataLoader(dataset_train, batch_size=64)
dataloader_test = DataLoader(dataset_test, batch_size=64)# 3. 搭建model
leNet = LeNet()
if torch.cuda.is_available():leNet = leNet.cuda()# 4. 创建损失函数
loss_fn = nn.CrossEntropyLoss()
if torch.cuda.is_available():loss_fn = loss_fn.cuda()# 5. 优化器
learning_rate = 0.1
optimizer = torch.optim.SGD(leNet.parameters(), lr=learning_rate)  # 随机梯度下降# 6. 设置训练网络的一些参数
total_train_step = 0  # 记录训练次数
total_test_step = 0  # 训练测试次数
epoch = 5  # 训练轮数# 补充tensorboard
writer = SummaryWriter("../../logs")# 开始训练
for i in range(epoch):print(f"--------第{i+1}轮训练开始--------")# 训练leNet.train()for data in dataloader_train:imgs, targets = dataif torch.cuda.is_available():imgs = imgs.cuda()targets = targets.cuda()outputs = leNet(imgs)loss = loss_fn(outputs, targets)optimizer.zero_grad()loss.backward()optimizer.step()total_train_step += 1if total_train_step % 100 == 0:print(f"训练次数:{total_train_step}---loss:{loss.item()}")writer.add_scalar("train_loss", loss.item(), total_train_step)# 测试leNet.eval()total_test_loss = 0  # 总体的误差total_accuracy = 0  # 总体的正确率with torch.no_grad():for data in dataloader_test:imgs, targets = dataif torch.cuda.is_available():imgs = imgs.cuda()targets = targets.cuda()outputs = leNet(imgs)loss = loss_fn(outputs, targets)total_test_loss += loss.item()accuracy = (outputs.argmax(1) == targets).sum()total_accuracy += accuracyprint(f"整体测试集上的loss:{total_test_loss}")print(f"整体测试集上的准确率:{total_accuracy/test_data_size}")writer.add_scalar("test_loss", total_test_loss, total_test_step)writer.add_scalar("total_accuracy", total_accuracy/test_data_size, total_test_step)total_test_step += 1# 保存每一轮训练的模型torch.save(leNet, f"leNet_{i+1}.pth")print("模式已保存")writer.close()

在这里插入图片描述

5轮训练中,第5轮的准确率是最高的,采用第5轮的模型进行测试:

test.py

import torch
import torchvision.transforms
from PIL import Imagefrom learning.lenet.model import LeNet# 需要测试的图片
image_path = "../../imgs/airplane.png"
image = Image.open(image_path)
image = image.convert('RGB')  # png图片多了一个透明度通道,修改成rgb三个通道
transform = torchvision.transforms.Compose([torchvision.transforms.Resize((32, 32)),torchvision.transforms.ToTensor()])
image = transform(image)
print(image.shape)# 引入网络架构# 读取网络模型  如果保存的模型是通过gpu训练出来的,需要添加 map_location=torch.device("cpu")
model_load = torch.load("leNet_5.pth", map_location=torch.device("cpu"))
# 原有的图片是没有bitch-size的,而我们的输入是需要的
image = torch.reshape(image, (1, 3, 32, 32))
model_load.eval()
with torch.no_grad():outputs = model_load(image)
print(outputs)classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')print(classes[outputs.argmax(1)])

在这里插入图片描述


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/84128.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Spark

Apache Spark是一种快速、通用、可扩展的大数据处理引擎,旨在处理大规模数据集并进行高效的数据分析。与Hadoop MapReduce相比,Spark具有更高的性能和更丰富的功能,可以处理更复杂的数据处理任务。以下是Apache Spark的一些基本概念&#xff…

idea2023+springboot 热部署配置

pom 中配置 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-devtools</artifactId><optional>true</optional></dependency>idea

蓝桥杯每日一题2023.9.21

蓝桥杯2021年第十二届省赛真题-异或数列 - C语言网 (dotcpp.com) 题目描述 Alice 和 Bob 正在玩一个异或数列的游戏。初始时&#xff0c;Alice 和 Bob 分别有一个整数 a 和 b&#xff0c;有一个给定的长度为 n 的公共数列 X1, X2, , Xn。 Alice 和 Bob 轮流操作&#xff0…

Rsync学习笔记2

Rsync&#xff1a; 增量操作&#xff1a; 1&#xff09; server01服务文件变动。 [rootserver03 tp5shop]# rsync -av /usr/local/nginx/html/tp5shop root192.168.17.109:/usr/local/nginx/html/ sending incremental file listsent 88,134 bytes received 496 bytes 177,…

如何评估测试用例的优先级?

评估测试用例的优先级&#xff0c;有助于我们及早发现和解决可能对系统稳定性和功能完整性产生重大影响的问题&#xff0c;助于提高测试质量&#xff0c;提高用户满意度。 如果没有做好测试用例的优先级评估&#xff0c;往往容易造成对系统关键功能和高风险场景测试的忽略&…

黑马JVM总结(十八)

&#xff08;1&#xff09;G1_FullGC的概念辨析 SerialGC&#xff1a;串行的&#xff0c;ParallelGC&#xff1a;并行的 &#xff0c;CMS和G1都是并发的 这几种垃圾回收器的新生代回收机制时相同的&#xff0c;SerialGC和ParalledGC&#xff1a;老年代内存不足触发的叫FullGC…

zabbix自定义监控、钉钉、邮箱报警 (五十六)

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 目录 一、实验准备 二、安装 三、添加监控对象 四、添加自定义监控项 五、监控mariadb 1、添加模版查看要求 2、安装mariadb、创建用户 3、创建用户文件 4、修改监控模版 5、…

Vue2的学习

computed计算属性 概念 基于现有数据&#xff0c;计算出来的新属性&#xff0c;依赖的数据变化&#xff0c;会自动重新计算 语法 声明在computed配置项中&#xff0c;一个计算属性对应一个函数这是一个属性{{计算属性名}}&#xff0c;不是方法注意不要忘记return <body…

uniapp开发h5,解决项目启动时,Network: unavailable问题

网上搜了很多&#xff0c;发现都说是要禁用掉电脑多余的网卡&#xff0c;这方法我试了没有好&#xff0c;不晓得为啥子&#xff0c;之后在网上看&#xff0c;uniapp的devServer vue2的话对标的就是webpack4的devserver&#xff08;除了复杂的函数配置项&#xff09;&#xff0c…

持有NPDP证书是否可以进入高薪行业?

对于NPDP&#xff0c;一部分人考完PMP&#xff0c;想要继续提升自己的职场竞争力&#xff0c;一部分人是想直接考NPDP&#xff0c;但不确定它是否真的可以帮助到自己的职业发展&#xff0c;还处于一个犹豫踌躇的状态&#xff0c;那么我就来详细的介绍下NPDP考试的相关讯息&…

HTML5+CSS3+JS小实例:鼠标控制飞机的飞行方向

实例:鼠标控制飞机的飞行方向 技术栈:HTML+CSS+JS 效果: 源码: 【html】 <!DOCTYPE html> <html><head><meta http-equiv="content-type" content="text/html; charset=utf-8"><meta name="viewport" conten…

【unity】关于技能释放shader.CreateGPUProgram造成卡顿,优化和定位方法。

关于优化方法&#xff0c;UWA这边有介绍 Unity移动端游戏性能优化简谱之 CPU耗时调优|单帧|动画|调用|unity|实例化_网易订阅 对此&#xff0c;我们可以将Shader通过ShaderVariantCollection收集要用到的变体并进行AssetBundle打包。在将该ShaderVariantCollection资源加载进内…

KubeSphere 在互联网医疗行业的应用实践

作者&#xff1a;宇轩辞白&#xff0c;运维研发工程师&#xff0c;目前专注于云原生、Kubernetes、容器、Linux、运维自动化等领域。 前言 2020 年我国互联网医疗企业迎来了“爆发元年”&#xff0c;越来越多居民在家隔离期间不方便去医院看诊&#xff0c;只好采取在线诊疗的手…

RK3568驱动指南|第五期-中断-

瑞芯微RK3568芯片是一款定位中高端的通用型SOC&#xff0c;采用22nm制程工艺&#xff0c;搭载一颗四核Cortex-A55处理器和Mali G52 2EE 图形处理器。RK3568 支持4K 解码和 1080P 编码&#xff0c;支持SATA/PCIE/USB3.0 外围接口。RK3568内置独立NPU&#xff0c;可用于轻量级人工…

瑞芯微:基于RK3568的ocr识别

光学字符识别&#xff08;Optical Character Recognition, OCR&#xff09;是指对文本资料的图像文件进行分析识别处理&#xff0c;获取文字及版面信息的过程。亦即将图像中的文字进行识别&#xff0c;并以文本的形式返回。OCR的应用场景 卡片证件识别类&#xff1a;大陆、港澳…

如何通过AI视频智能分析技术,构建着装规范检测/工装穿戴检测系统?

众所周知&#xff0c;规范着装在很多场景中起着重要的作用。违规着装极易增加安全隐患&#xff0c;并且引发安全事故和质量问题&#xff0c;例如&#xff0c;在化工工厂中&#xff0c;倘若员工没有穿戴符合要求的特殊防护服和安全鞋&#xff0c;将有极大可能受到有害物质的侵害…

Doxygen在vs code配置

找到这个 就在这里面配置&#xff0c;如果要在原有的下面添加别忘了后面加个逗号&#xff0c;我在他前面加的所以我在上面加了个 //基础设置 “doxdocgen.c.triggerSequence”: “/", “doxdocgen.c.firstLine”: "/", “doxdocgen.c.commentPrefix”: &quo…

〔023〕Stable Diffusion 之 界面主题 篇

✨ 目录 🎈 系统内置主题🎈 kitchen Theme 主题🎈 Catppuccin Theme 主题🎈 Cozy Nest 主题🎈 系统内置主题 可以通过命令行修改主题,在 webui-user.bat 文件中 set COMMANDLINE_ARGS 参数后面添加 --theme dark 来设置深色主题当然,系统设置里面也自带了很多的主题…

C++ 异常处理学习笔记

一、使用情况 1、数组越界&#xff1a;包括数组索引小于0&#xff0c;或者大于数组长度 2、空指针 可以抛出(throw)各种类型的异常&#xff0c;catch的地方接收就可以

一款非常容易上手的报表工具,简单操作实现BI炫酷界面数据展示,驱动支持众多不同类型的数据库,可视化神器,免开源了

一款非常容易上手的报表工具&#xff0c;简单操作实现BI炫酷界面数据展示&#xff0c;驱动支持众多不同类型的数据库&#xff0c;可视化神器&#xff0c;免开源了。 在互联网数据大爆炸的这几年&#xff0c;各类数据处理、数据可视化的需求使得 GitHub 上诞生了一大批高质量的…