6.6 实现卷积神经网络LeNet训练并预测手写体数字

模型架构

在这里插入图片描述
在这里插入图片描述

代码实现

import torch
from torch import nn
from d2l import torch as d2l
net = nn.Sequential(nn.Conv2d(1,6,kernel_size=5,padding=2),nn.Sigmoid(),#padding=2补偿5x5卷积核导致的特征减少。nn.AvgPool2d(kernel_size=2,stride=2),nn.Conv2d(6,16,kernel_size=5),nn.Sigmoid(),nn.AvgPool2d(kernel_size=2,stride=2),nn.Flatten(),nn.Linear(16*5*5,120),nn.Sigmoid(),nn.Linear(120,84),nn.Sigmoid(),nn.Linear(84,10)
)
'''定义X,并打印模型的形状'''
# 第一个参数是样本
X = torch.rand(size=(1,1,28,28),dtype=torch.float32)
for layer in net:X = layer(X)print(layer.__class__.__name__,'output shape: \t',X.shape)
# 输出如下:
Conv2d output shape: 	 torch.Size([1, 6, 28, 28])
Sigmoid output shape: 	 torch.Size([1, 6, 28, 28])
AvgPool2d output shape: 	 torch.Size([1, 6, 14, 14])
Conv2d output shape: 	 torch.Size([1, 16, 10, 10])
Sigmoid output shape: 	 torch.Size([1, 16, 10, 10])
AvgPool2d output shape: 	 torch.Size([1, 16, 5, 5])
Flatten output shape: 	 torch.Size([1, 400])
Linear output shape: 	 torch.Size([1, 120])
Sigmoid output shape: 	 torch.Size([1, 120])
Linear output shape: 	 torch.Size([1, 84])
Sigmoid output shape: 	 torch.Size([1, 84])
Linear output shape: 	 torch.Size([1, 10])
'''定义训练批次并加载训练集和测试集'''
batch_size = 256
# 按照batch_size把数据集取出来。取出来之后是放到内存中的,后面要把它加载到GPU中
train_iter,test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)
# 计算预测正确的个数
def accuracy(y_hat,y):'''计算预测正确的数量'''if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:# y_hat是下标表示类别,值是该类别的概率。模型结果是预测10个类的概率,谁的概率最大,就取谁的下标y_hat = y_hat.argmax(axis=1)#  y_hat.type(y.dtype):因为==对数据类型很敏感,因此我们将y_hat的数据类型转换为与y的数据类型一致。#  y_hat.type(y.dtype) == y,将预测值y_hat与真实值y比较,返回一个包含 0和1的张量,赋值给cam,最后求和会得到正确预测的数量。cam = y_hat.type(y.dtype) == yreturn float(cam.type(y.dtype).sum())
def evaluate_accuracy_gpu(net,data_iter,device=None):if isinstance(net,nn.Module):net.eval() # 将模型设置为评估模式if not device:'''iter(net.parameters())是将参数集合转换为迭代器,并获取其中的第一个元素next(iter(net.parameters())).device ,指取到net.parameters()的第一个元素,获取该元素的设备。'''device = next(iter(net.parameters())).device# Accumulator用于对多个变量进行累加,d2l.Accumulator(2) 是在Accumulator实例中创建了2个变量,分别用于存储正确预测的数量和预测的总数量。当我们遍历数据集时,两者都随着时间的推移而累加。metric = d2l.Accumulator(2) # 正确预测数,预测总数with torch.no_grad():for X,y in data_iter:if isinstance(X,list): # 详见文章最下面的补充内容X = [x.to(device) for x in X] # 令X使用设备deviceelse:X = X.to(device)y = y.to(device)# y.numel()是批次中样本的数量,accuracy(net(X),y)是用于计算模型在输入数据X上的输出结果与标签Y之间的准确率。# metric.add函数将正确预测的数量 和 样本数量作为参数传递进去,用于记录和累计这些指标的值。metric.add(accuracy(net(X),y),y.numel())return metric[0]/metric[1] # 返回准确率,其中metric[0]存放的是正确预测的个数,metric[1]存放的是样本数量,
def train_ch6(net,train_iter,test_iter,num_epochs,lr,device):def init_weights(m):if type(m) == nn.Linear or type(m) == nn.Conv2d: # 对神经网络中的线性层和卷积层的权重进行初始化nn.init.xavier_uniform_(m.weight) #用于初始化权重的函数,net.apply(init_weights)print('training on',device)net.to(device) # 设置模型使用deviceoptimizer = torch.optim.SGD(net.parameters(),lr=lr)loss = nn.CrossEntropyLoss()'''该代码创建了一个名为animator的动画器,用于在训练过程中可视化损失函数和准确率的变化情况'''animator = d2l.Animator(xlabel='epoch',xlim=[1,num_epochs],legend=['train loss','train acc','test acc'])timer,num_batches = d2l.Timer(),len(train_iter)for epoch in range(num_epochs):# 创建 Accumulator类,统计训练损失之和,正确预测个数之和,样本数metric = d2l.Accumulator(3)net.train()for i,(X,y) in enumerate(train_iter):timer.start()optimizer.zero_grad()X,y = X.to(device),y.to(device)y_hat = net(X)l = loss(y_hat,y)l.backward()optimizer.step()with torch.no_grad():metric.add(l*X.shape[0], d2l.accuracy(y_hat,y), X.shape[0])timer.stop()train_l = metric[0] / metric[2] # 损失之和 / 样本数train_acc = metric[1] / metric[2] # 正确预测个数 / 样本数if (i+1) % (num_batches//5)==0 or i == num_epochs-1:animator.add(epoch + (i+1)/num_epochs,(train_l,train_acc,None))test_acc = evaluate_accuracy_gpu(net,test_iter)animator.add(epoch+1,(None,None,test_acc))print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, 'f'test acc {test_acc:.3f}')print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec 'f'on {str(device)}')
# 定义学习率和批次 开始训练
lr,num_epochs = 0.9,10
train_ch6(net,train_iter,test_iter,num_epochs,lr,d2l.try_gpu())

在这里插入图片描述

练习

把平均汇聚层改为最大汇聚层

在这里插入图片描述

把平均池化改为最大池和把激活函数改为RelU之后的效果

net = nn.Sequential(nn.Conv2d(1,6,kernel_size=5,padding=2),nn.ReLU(),#padding=2补偿5x5卷积核导致的特征减少。nn.MaxPool2d(kernel_size=2,stride=2),nn.Conv2d(6,16,kernel_size=5),nn.ReLU(),nn.MaxPool2d(kernel_size=2,stride=2),nn.Flatten(),nn.Linear(16*5*5,120),nn.ReLU(),nn.Linear(120,84),nn.Sigmoid(), #注意,此处不能改为RelU,此处的sigmoid是把预测结果映射成概率nn.Linear(84,10)
)

在这里插入图片描述

使用训练好的模型进行预测

y_hat = net(x)

补充:

isinstance(net,nn.Module)

isinstance(net,nn.Module)是Python的内置函数,用于判断一个对象是否属于制定类或其子类的实例。如果net是nn.Module类或子类的实例,那么表达式返回True,否则返回False. nn.Module是pytorch中用于构建神经网络模型的基类,其他神经网络都会继承它,因此使用 isinstance(net,nn.Module),可以确定Net对象是否为一个有效的神经网络模型。

`nn.init.xavier_uniform_(m.weight)

nn.init.xavier_uniform_(m.weight) 是一个用于初始化权重的函数,采用的是 Xavier 均匀分布初始化方法。

在神经网络中,权重的初始化非常重要,合适的初始化可以帮助网络更好地学习和收敛。Xavier 初始化方法是一种常用的权重初始化方法之一,旨在使权重在前向传播过程中保持方差不变。

具体而言,nn.init.xavier_uniform_() 函数会对输入的权重张量 m.weight 进行操作,将其初始化为一个均匀分布中的随机值。这个均匀分布的范围根据权重张量的形状进行调整,以保持前向传播过程中特征的方差稳定。

通过使用 Xavier 初始化方法,可以加速神经网络的训练过程,并且有助于避免梯度消失或梯度爆炸等问题。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/26855.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

OpenStreetMap数据转3D场景【Python + PostgreSQL】

很长一段时间以来,我对 GIS 和渲染感兴趣,在分别尝试这两者之后,我决定最终尝试以 3D 方式渲染 OpenStreetMap 中的地理数据,重点关注不超过城市的小规模。 在本文中,我将介绍从建筑形状生成三角形网格、以适合 Blend…

iperf3-性能测试

iperf3-性能测试 安装1.apt安装2.源码安装 使用方法iperf原理测试参考文档性能测试客户端服务端 官方文档:https://iperf.fr/iperf-doc.php 安装 1.apt安装 sudo apt-get install iperf32.源码安装 # 按照官方说明安装 ./configure make sudo make install执行编…

GATK ApplyBQSRSpark 过程中因No space left on device终止

Error: GATK ApplyBQSRSpark 过程中因No space left on device终止 执行命令: nohup time ./gatk --java-options "-Xmx128G" ApplyBQSRSpark --spark-master local[20] -R ../../alignment/hg38/hg38.fa -I ../../alignment/bam/P368T.s…

微信小程序nodejs+vue+uniapp高校食堂线上预约点餐系统

本次设计任务是要设计一个食堂线上预约点餐系统,通过这个系统能够满足管理员及学生的食堂线上预约点餐分享功能。系统的主要包括首页、个人中心、学生管理、菜品分类管理、菜品管理、关于我们管理、意见反馈、系统管理、订单管理等功能。 开发语言 node.js 框架&am…

【论文阅读】对抗溯源图主机入侵检测系统的模仿攻击(NDSS-2023)

作者:伊利诺伊大学芝加哥分校-Akul Goyal、Gang Wang、Adam Bates;维克森林大学-Xueyuan Han、 引用:Goyal A, Han X, Wang G, et al. Sometimes, You Aren’t What You Do: Mimicry Attacks against Provenance Graph Host Intrusion Detect…

BenchmarkSQL 支持 TiDB 驱动以及 tidb-loadbalance

作者: GangShen 原文来源: https://tidb.net/blog/3c274180 使用 BenchmarkSQL 对 TiDB 进行 TPC-C 测试 众所周知 TiDB 是一个兼容 MySQL 协议的分布式关系型数据库,用户可以使用 MySQL 的驱动以及连接方式连接 TiDB 进行使用&#xff0…

Git从远程仓库中删除文件,并上传新文件

目录 删除: 拉取远程分支的更新: ​编辑 首先查看git状态: ​编辑 删除文件并提交版本库: 提交: 上传新文件: 首先查看git状态: 提交到暂存区: 提交到版本库: 上…

基于Spring Boot的在线视频教育培训网站设计与实现(Java+spring boot+MySQL)

获取源码或者论文请私信博主 演示视频: 基于Spring Boot的在线视频教育培训网站设计与实现(Javaspring bootMySQL) 使用技术: 前端:html css javascript jQuery ajax thymeleaf 微信小程序 后端:Java sp…

skywalking日志收集

文章目录 一、介绍二、添加依赖三、修改日志配置1. 添加链路表示traceId2. 添加链路上下文3. 异步日志 四、收集链路日志 一、介绍 在上一篇文章skywalking全链路追踪中我们介绍了在微服务项目中使用skywalking进行服务调用链路的追踪。 本文在全链路追踪的基础上&#xff0c…

gradle项目Connection timed out,build时先下载gradle问题download gradle-x.x-bin.zip

IDEA 导入 Gradle 项目,编译的时候会默认下载 配置版本的Gradle.zip问题,一般会下载失败,提示Connection timed out,连接超时。 解决办法: 修改项目根目录下gradle目录下的gradle-wrapper.properties文件,…

LLM as Co-pilot:AutoDev 1.0 发布,开源全流程 AI 辅助编程

四月,在那篇《AutoDev:AI 突破研发效能,探索平台工程新机遇》,我们初步拟定了 AI 对于研发的影响。我们有了几个基本的假设: 中大型企业将至少拥有一个私有化的大语言模型。只有构建端到端工具才能借助 AI 实现增质提效…

STM32入门——定时器

内容为江科大STM32标准库学习记录 TIM简介 TIM(Timer)定时器定时器可以对输入的时钟进行计数,并在计数值达到设定值时触发中断16位计数器、预分频器、自动重装寄存器的时基单元,在72MHz计数时钟下可以实现最大59.65s的定时&…

【Docker】性能测试监控平台搭建:InfluxDB+Grafana+Jmeter+cAdvisor

前言 在做性能测试时,如果有一个性能测试结果实时展示的页面,可以极大的提高我们对系统性能表现的掌握程度,进而提高我们的测试效率。但是我们每次打开Jmeter都会有几个硕大的字提示别用GUI模式进行负载测试,而且它自带的监视器效…

系统架构设计师-软件架构设计(7)

目录 大型网站系统架构演化 一、第一阶段:单体架构 到 第二阶段:垂直架构 二、第三阶段:使用缓存改善网站性能 1、缓存与数据库的数据一致性问题 2、缓存技术对比【MemCache与Redis】 3、Redis分布式存储方案 4、Redis集群切片的常见方式 …

c++ boost circular_buffer

boost库中的 circular_buffer顾名思义是一个循环缓冲器,其 capcity是固定的当容量满了以后,插入一个元素时,会在容器的开头或结尾处删除一个元素。 circular_buffer为了效率考虑,使用了连续内存块保存元素 使用固定内存&#x…

为什么要选择文件传输软件?有哪些最佳高速文件传输软件?

是否经历过这样的场景,正在努力地完成工作任务,但是由于制作的数据无法及时传送给合作伙伴,工作流程被打断了?这听起来很令人沮丧,对吧?可是,这种情况在现实中并不罕见。 因此,需要…

OpenCv.js(图像处理)学习历程

opencv.js官网 4.5.0文档 以下内容整理于opencv.js官网。 简介 OpenCV由Gary Bradski于1999年在英特尔创建。第一次发行是在2000年。OpenCV支持c、Python、Java等多种编程语言,支持Windows、Linux、Os X、Android、iOS等平台。基于CUDA和OpenCL的高速GPU操作接口也…

java泛型和通配符的使用

泛型机制 本质是参数化类型(与方法的形式参数比较,方法是参数化对象)。 优势:将类型检查由运行期提前到编译期。减少了很多错误。 泛型是jdk5.0的新特性。 集合中使用泛型 总结: ① 集合接口或集合类在jdk5.0时都修改为带泛型的结构② 在实例化集合类时…

Unity数字可视化学校_昼夜(三)

1、删除不需要的 UI using System.Collections; using System.Collections.Generic; using UnityEngine; using UnityEngine.UI;public class EnvControl : MonoBehaviour {//UIprivate Button btnTime;private Text txtTime; //材质public List<Material> matListnew Li…

docker中的jenkins去配置sonarQube

docker中的jenkins去配置sonarQube 1、拉取sonarQube macdeMacBook-Pro:~ mac$ docker pull sonarqube:8.9.6-community 8.9.6-community: Pulling from library/sonarqube 8572bc8fb8a3: Pull complete 702f1610d53e: Pull complete 8c951e69c28d: Pull complete f95e4f8…