LeNet网络分析与demo实例

参考自 

  • up主的b站链接:霹雳吧啦Wz的个人空间-霹雳吧啦Wz个人主页-哔哩哔哩视频
  • 这位大佬的博客 Fun'_机器学习,pytorch图像分类,工具箱-CSDN博客

网络分析:

最好是把这个图像和代码对着来看然后进行分析的时候比较快

# 使用torch.nn包来构建神经网络.
import torch.nn as nn
import torch.nn.functional as Fclass LeNet(nn.Module): 					# 继承于nn.Module这个父类def __init__(self):						# 初始化网络结构super(LeNet, self).__init__()    	# 多继承需用到super函数self.conv1 = nn.Conv2d(3, 16, 5)self.pool1 = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(16, 32, 5)self.pool2 = nn.MaxPool2d(2, 2)self.fc1 = nn.Linear(32*5*5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):			 # 正向传播过程x = F.relu(self.conv1(x))    # input(3, 32, 32) output(16, 28, 28)x = self.pool1(x)            # output(16, 14, 14)x = F.relu(self.conv2(x))    # output(32, 10, 10)x = self.pool2(x)            # output(32, 5, 5)x = x.view(-1, 32*5*5)       # output(32*5*5)x = F.relu(self.fc1(x))      # output(120)x = F.relu(self.fc2(x))      # output(84)x = self.fc3(x)              # output(10)return x

Conv1:

输入矩阵  : 32*32*3

卷积层:5*5*3   16个

输出:(32-5)/1+1 = 28    28*28*16

MaxPool:

输入矩阵:28*28*16

2*2最大下采样

输出:14*14*16

Conv2:

输入矩阵  : 14*14*16

卷积层:5*5*16   32个

输出:(14-5)/1+1 = 10    10*10*32

MaxPool:

输入矩阵:10*10*32

2*2最大下采样

输出:5*5*32

全连接Linear:

Linear(32*5*5,120)

Linear(120,84)

Linear(120,10)最后这个数字要取决于你要分几类

经卷积后的输出层尺寸计算公式为:

Output= (W−F+2P)/S​+1

输入图片大小 W×W(一般情况下Width=Height)
Filter大小 F×F
步长 S
padding的像素数 P

经过上述分析就可以pytorch构建网络了:

import torch.nn as nn
import torch.nn.functional as F
import numpy as npclass Reshape(nn.Module):def forward(self,x):return x.view(-1,1,28,28)class LeNet(nn.Module):def __init__(self):super(LeNet,self).__init__()self.conv1 = nn.Conv2d(in_channels=1,out_channels=6,kernel_size=5,padding=2)#输入数据的维度 卷积核的个数(即输出数据的维度) 卷积核的大小5*5self.pool1 = nn.MaxPool2d(2,2)#池化层的大小,步长self.conv2 = nn.Conv2d(6,16,5)self.pool2 = nn.MaxPool2d(2,2)self.fc1 = nn.Linear(16*5*5,120)self.fc2 = nn.Linear(120,84)self.fc3 = nn.Linear(84,10)def forward(self,x):x = F.relu(self.conv1(x))       #input(1,28,28) output(6,28,28)x = self.pool1(x)               # output(6,14,14) 池化层不改变数据的维度 (w - f + 2*p)/s+1 w为输入图片大小,f为卷积核的大小,p为填充否,s为步长 计算可得输出图片的大小x = F.relu(self.conv2(x))       #output(16,10,10)x = self.pool2(x)               #output(16,5,5)x = x.view(-1,16*5*5)           #output(16*5*5) 展成一列向量进行操作x = F.relu(self.fc1(x))         #output(120)x = F.relu(self.fc2(x))         #output(84)x = F.relu(self.fc3(x))         #output(10)return ximport torch
input1 = torch.rand([32,1,28,28])
model = LeNet()
print(model)
output = model(input1)

数据集介绍

利用torchvision.datasets函数可以在线导入pytorch中的数据集,包含一些常见的数据集如MNIST等

# 导入10000张测试图片
test_set = torchvision.datasets.CIFAR10(root='./data', train=False,	# 表示是数据集中的测试集download=False,transform=transform)
# 加载测试集
test_loader = torch.utils.data.DataLoader(test_set, batch_size=10000, # 每批用于验证的样本数shuffle=False, num_workers=0)
# 获取测试集中的图像和标签,用于accuracy计算
test_data_iter = iter(test_loader)
test_image, test_label = test_data_iter.next()

2.2 训练过程

epoch   : 对训练集的全部数据进行一次完整的训练,称为 一次 epoch
batch   : 由于硬件算力有限,实际训练时将训练集分成多个批次训练,每批数据的大小为 batch_size
iteration 或 step :   对一个batch的数据训练的过程称为 一个 iteration 或 step

训练过程

net = LeNet()						  				# 定义训练的网络模型
loss_function = nn.CrossEntropyLoss() 				# 定义损失函数为交叉熵损失函数 
optimizer = optim.Adam(net.parameters(), lr=0.001)  # 定义优化器(训练参数,学习率)for epoch in range(5):  # 一个epoch即对整个训练集进行一次训练running_loss = 0.0time_start = time.perf_counter()for step, data in enumerate(train_loader, start=0):   # 遍历训练集,step从0开始计算inputs, labels = data 	# 获取训练集的图像和标签optimizer.zero_grad()   # 清除历史梯度# forward + backward + optimizeoutputs = net(inputs)  				  # 正向传播loss = loss_function(outputs, labels) # 计算损失loss.backward() 					  # 反向传播optimizer.step() 					  # 优化器更新参数# 打印耗时、损失、准确率等数据running_loss += loss.item()if step % 1000 == 999:    # print every 1000 mini-batches,每1000步打印一次with torch.no_grad(): # 在以下步骤中(验证过程中)不用计算每个节点的损失梯度,防止内存占用outputs = net(test_image) 				 # 测试集传入网络(test_batch_size=10000),output维度为[10000,10]predict_y = torch.max(outputs, dim=1)[1] # 以output中值最大位置对应的索引(标签)作为预测输出accuracy = (predict_y == test_label).sum().item() / test_label.size(0)print('[%d, %5d] train_loss: %.3f  test_accuracy: %.3f' %  # 打印epoch,step,loss,accuracy(epoch + 1, step + 1, running_loss / 500, accuracy))print('%f s' % (time.perf_counter() - time_start))        # 打印耗时running_loss = 0.0print('Finished Training')# 保存训练得到的参数
save_path = './Lenet.pth'
torch.save(net.state_dict(), save_path)


 

测试:

# 导入包
import torch
import torchvision.transforms as transforms
from PIL import Image
from model import LeNet# 数据预处理
transform = transforms.Compose([transforms.Resize((32, 32)), # 首先需resize成跟训练集图像一样的大小transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])# 导入要测试的图像(自己找的,不在数据集中),放在源文件目录下
im = Image.open('horse.jpg')
im = transform(im)  # [C, H, W]
im = torch.unsqueeze(im, dim=0)  # 对数据增加一个新维度,因为tensor的参数是[batch, channel, height, width] # 实例化网络,加载训练好的模型参数
net = LeNet()
net.load_state_dict(torch.load('Lenet.pth'))# 预测
classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')
with torch.no_grad():outputs = net(im)predict = torch.max(outputs, dim=1)[1].data.numpy()
print(classes[int(predict)])

输出即为预测的标签。

其实预测结果也可以用 softmax 表示,输出10个概率:

with torch.no_grad():outputs = net(im)predict = torch.softmax(outputs, dim=1)
print(predict)

输出结果中最大概率值对应的索引即为 预测标签 的索引。

tensor([[2.2782e-06, 2.1008e-07, 1.0098e-04, 9.5135e-05, 9.3220e-04, 2.1398e-04,3.2954e-08, 9.9865e-01, 2.8895e-08, 2.8820e-07]])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/241797.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Go 泛型之类型参数

Go 泛型之类型参数 文章目录 Go 泛型之类型参数一、Go 的泛型与其他主流编程语言的泛型差异二、返回切片中值最大的元素三、类型参数(type parameters)四、泛型函数3.1 泛型函数的结构3.2 调用泛型函数3.3 泛型函数实例化(instantiation&…

WARNING: HADOOP_SECURE_DN_USER has been replaced by HDFS_DATANODE_SECURE_USER.

Hadoop启动时警告,但不影响使用,强迫症的我还是决定寻找解决办法 WARNING: HADOOP_SECURE_DN_USER has been replaced by HDFS_DATANODE_SECURE_USER. Using value of HADOOP_SECURE_DN_USER.原因是Hadoop安装配置于root用户下,对文件需要进…

案例144:基于微信小程序的自修室预约系统

文末获取源码 开发语言:Java 框架:SSM JDK版本:JDK1.8 数据库:mysql 5.7 开发软件:eclipse/myeclipse/idea Maven包:Maven3.5.4 小程序框架:uniapp 小程序开发软件:HBuilder X 小程序…

Spring中的上下文工具你写的可能有bug

文章目录 前言功能第一种:ApplicationContext第二种方式:ApplicationContextAware第三种:BeanFactoryPostProcessor 源码第一种第二种第三种 前言 本篇是针对如何写一个比较好的spring工具的一个探讨。 功能 下面三种方式,你觉…

Odoo16 实用功能之Form视图详解(表单视图)

目录 1、什么是Form视图 2、Form视图的结构 3、源码示例 1、什么是Form视图 Form视图是用于查看和编辑数据库记录的界面。每个数据库模型在Odoo中都有一个Form视图,用于显示该模型的数据。Form视图提供了一个可编辑的界面,允许用户查看和修改数据库记…

[python]用python实现对arxml文件的操作

目录 关键词平台说明一、背景二、方法2.1 库2.2 code 关键词 python、excel、DBC、openpyxl 平台说明 项目Valuepython版本3.6 一、背景 有时候需要批量处理arxml文件(ARXML 文件符合 AUTOSAR 4.0 标准),但是工作量太大,阔以考虑用python。 二、方…

最新版 JESD79-5B,2022年,JEDEC 内存SDRAM规范

本标准定义了DDR5 SDRAM规范,包括特性、功能、交流和直流特性、封装以及球/信号分配。本标准旨在为x4、x8和x16 DDR5 SDRAM设备定义符合JEDEC标准的8 Gb至32 Gb的最低要求。该标准是基于DDR4标准(JESD79-4)和DDR、DDR2、DDR3和LPDDR4标准的一…

智能优化算法应用:基于金枪鱼群算法3D无线传感器网络(WSN)覆盖优化 - 附代码

智能优化算法应用:基于金枪鱼群算法3D无线传感器网络(WSN)覆盖优化 - 附代码 文章目录 智能优化算法应用:基于金枪鱼群算法3D无线传感器网络(WSN)覆盖优化 - 附代码1.无线传感网络节点模型2.覆盖数学模型及分析3.金枪鱼群算法4.实验参数设定5.算法结果6.…

1856_emacs_calc使用介绍与故事

Grey 全部学习内容汇总: GitHub - GreyZhang/g_org: my learning trip for org-mode 1856_emacs_calc使用介绍与故事 calc是emacs内置的一个计算器,可以提供多种计算表达方式并且可以支持org-mode中的表格功能。 主题由来介绍 我是因为想要了解org-…

采草(动态规划)

先说说我的思路吧 下面是部分聊天记录 赤坂 龍之介 2023/12/22 11:06:04 就像我之前说的那样,我把每一个药草的价值除以时间,得出了新的价值评估标准:采摘这个药草时,每分钟的价值 赤坂 龍之介 2023/12/22 11:07:00 然后排…

2023年小型计算机视觉总结

在过去的十年中,出现了许多涉及计算机视觉(CV)的项目,无论是小型的概念验证项目还是更大规模的生产应用。应用计算机视觉的方法是相当标准化的: 1、定义问题(分类、检测、跟踪、分割)、输入数据(图片的大小和类型、视野)和类别(正是我们想要的) 2、注释…

Python算法例27 对称数

1. 问题描述 对称数是一个旋转180后(倒过来)看起来与原数相同的数,找到所有长度为n的对称数。 2. 问题示例 给出n2,返回["11","69","88&#x…

详解Vue3中的基础路由和动态路由

本文主要介绍Vue3中的基础路由和动态路由。 目录 一、基础路由二、动态路由 Vue3中的路由使用的是Vue Router库,它是一个官方提供的用于实现应用程序导航的工具。Vue Router在Vue.js的核心库上提供了路由的功能,使得我们可以在单页应用中实现页面的切换、…

QT编写应用的界面自适应分辨率的解决方案

博主在工作机上完成QT软件开发(控件大小与字体大小比例正常),部署到客户机后,发现控件大小与字体大小比例失调,具体表现为控件装不下字体,即字体显示不全,推测是软件不能自适应分辨率导致的。 文…

C/C++ 共用体union的应用和struct不同

共用体union是一种数据格式,它能够存储不同的数据类型,但只能同时存储其中的一种类型。也就是说,结构体同时存储int、long和double,共用体只能春初int、long或double,共用体的语法与结构体相似,但含义不同。例如下面的声明&#x…

基于javaSpringbootmysql的小型超市商品展销系统01635-计算机毕业设计项目选题推荐(免费领源码)

摘 要 科技进步的飞速发展引起人们日常生活的巨大变化,电子信息技术的飞速发展使得电子信息技术的各个领域的应用水平得到普及和应用。信息时代的到来已成为不可阻挡的时尚潮流,人类发展的历史正进入一个新时代。在现实运用中,应用软件的工作…

【SpringCloud】-GateWay源码解析

GateWay系列 【SpringCloud】-GateWay网关 一、背景介绍 当一个请求来到 Spring Cloud Gateway 之后,会经过一系列的处理流程,其中涉及到路由的匹配、过滤器链的执行等步骤。今天我们来说说请求经过 Gateway 的主要执行流程和原理是什么吧 二、正文 …

【教3妹学编程-算法题】收集足够苹果的最小花园周长

3妹:“在小小的花园里面挖呀挖呀挖,种小小的种子开小小的花” 2哥 : 3妹也会唱这首儿歌呀, 这首儿歌在五一期间很火啊。 3妹:是呀, 小朋友们都喜欢唱,我这个200多个月的大朋友也喜欢唱,哈哈 2哥…

仅操作一台设备,如何实现本地访问另一个相同网段的私网?

正文共:1034 字 8 图,预估阅读时间:4 分钟 书接上文(地址重叠时,用户如何通过NAT访问对端IP网络?),我们已经通过两台设备的组合配置实现了通过IP地址进行访问。但一般场景中&#xf…