23/76-LeNet

LeNet
早期成功的神经网络。
先使用卷积层来学习图片空间信息。
然后使用全连接层转换到类别空间。

在这里插入图片描述

#In[]
'''
LeNet,上世纪80年代的产物,最初为了手写识别设计
'''
from d2l import torch as d2l
import torch 
from torch import nn
from torch.nn.modules.loss import CrossEntropyLossfrom torch.utils import data
import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt
import Common_functions'''
LeNet:
两个卷积层,两个池化层,三个线性层
假定为MNIST设计,输入为(batch_size,1,28,28)
'''class Reshape(torch.nn.Module):def forward(self,x):return x.view(-1,1,28,28)net = nn.Sequential(nn.Conv2d(in_channels=1,out_channels=6,kernel_size=(5,5),padding=2),nn.Sigmoid(), #输出:(6,28,28)nn.AvgPool2d(kernel_size=(2,2)), #不指定stride默认不重叠 输出(6,14,14)nn.Conv2d(6,16,kernel_size=(5,5)),nn.Sigmoid(),#输出(16,10,10)nn.AvgPool2d(kernel_size=(2,2)),#输出(16,5,5)nn.Flatten(),nn.Linear(16*5*5,120),nn.Sigmoid(),#nn.Linear(120,84),nn.Sigmoid(),nn.Linear(84,10)
)X=torch.rand(size=(1,1,28,28),dtype=torch.float32)
for layer in net:X=layer(X)print(layer.__class__.__name__,'output shape: \t',X.shape)#In[]batch_size = 256
train_iter,test_iter=d2l.load_data_fashion_mnist(batch_size=batch_size)#对evaluate_accuracy函数进行轻微修改
#使用GPU计算模型在数据集上的精度
#计算网络在测试数据集上面的准确率
#由于完整的测试数据集位于内存中,因此在模型使用GPU预测测试数据集之前,我们需要将其复制到显存中。
def evaluate_accuracy_gpu(net,data_iter,device=None):if isinstance(net,nn.Module):net.eval() #网络用于测试数据if not device:device = next(iter(net.parameters())).device #如果没有指定device设备,device设备则使用第一层网络参数的设备accumulator = d2l.Accumulator(2) #累加器里面包含两个元素for X,y in data_iter:if isinstance(X,list):X = [x.to(device) for x in X] #X为list类型时,需要加X里面每个元素都复制到device设备上面来else:X = X.to(device)y = y.to(device)accumulator.add(d2l.accuracy(net(X),y),y.numel()) #累加器第一个元素为在每一个batch_size中预测准确的个数,第二个元素为每一个batch_size中样本总数目,然后依次循环累加,得到测试数据集上面预测准确的总数目,以及数据集总数目return accumulator[0]/accumulator[1] #算出模型预测准确率def train_ch6(net,train_iter,test_iter,num_epochs,lr,device):def init_weights(m):#手动初始化模型参数if type(m) == nn.Linear or type(m) == nn.Conv2d:nn.init.xavier_uniform_(m.weight) #使用xavier_uniform分布初始化参数net.apply(init_weights)net.to(device)#将模型复制到gpu上面print('training on',device)loss = nn.CrossEntropyLoss() #定义lossoptim = torch.optim.SGD(net.parameters(),lr=lr) #定义优化器animator = d2l.Animator(xlabel='epoch',xlim=[1,num_epochs],legend=['train_loss','train_acc','test_acc'])timer = d2l.Timer()num_batches = len(train_iter)for epoch in range(num_epochs):net.train()#模型开始训练,需要放在第一层循环里面,因为后面evaluate_accuracy_gpu()函数里面有net.eval(),将模型改变为测试状态,因此需要在每一个循环epoch后面手动再加上模型开始处于训练状态accumulator = d2l.Accumulator(3) #累加器for i,(X,y) in enumerate(train_iter):timer.start()optim.zero_grad()X = X.to(device)#将X复制到gpu上面y = y.to(device) #将y复制到gpu上面y_hat = net(X) #得到模型训练后的输出标签y_hatl = loss(y_hat,y)#计算每一个batch_size的lossl.backward() #计算梯度optim.step() #使用优化器更新模型参数with torch.no_grad():#不需要模型梯度accumulator.add(l*X.shape[0],d2l.accuracy(y_hat,y),X.shape[0])timer.stop()train_loss = accumulator[0]/accumulator[2] #从累加器里面获得所有训练集的loss之和train_acc = accumulator[1]/accumulator[2] #从累加器里面获得所有训练集的准确数之和if (i+1) % (num_batches // 5) == 0 or i == num_batches-1:animator.add(epoch+(i+1)/num_batches,(train_loss,train_acc,None))test_accuracy = evaluate_accuracy_gpu(net,test_iter) #每次训练完一个epoch后的模型用于测试数据集上面计算测试精确度animator.add(epoch+1,(None,None,test_accuracy))print(f'模型训练完最后一轮时 train_loss:{train_loss},train_acc:{train_acc},test_acc:{test_accuracy}')print(f'{num_epochs*accumulator[2]/timer.sum()}examples/second on {str(device)}')#打印出模型每秒能处理多少个样本数lr,num_epochs= 0.9,10
train_ch6(net,train_iter=train_iter,test_iter=test_iter,lr=lr,num_epochs=num_epochs,device=d2l.try_gpu())
'''
输出结果:
模型训练完最后一轮时 train_loss:0.4322478462855021,train_acc:0.8396666666666667,test_acc:0.8163
55954.65804440994examples/second on cuda:0
'''#训练
if torch.cuda.is_available():device = "cuda:0"
else:device = "cpu"
device = torch.device(device)Common_functions.train_device(net,train_iter,test_iter,lr=0.9,device=device)
# %%plt.show()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/629009.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

工业平板定制方案_基于联发科、紫光展锐平台的工业平板电脑方案

工业平板主板采用联发科MT6762平台方案,搭载Android 11.0操作系统, 主频最高2.0GHz,效能有大幅提升;采用12nm先进工艺,具有低功耗高性能的特点。 该工业平板主板搭载了IMG GE8320图形处理器,最高主频为680MHz, 支持108…

Flume 之自定义Sink

1、简介 前文我们介绍了 Flume 如何自定义 Source, 并进行案例演示,本文将接着前文,自定义Sink,在这篇文章中,将使用自定义 Source 和 自定义的 Sink 实现数据传输,让大家快速掌握Flume这门技术。 2、自定…

Python - 深夜数据结构与算法之 Sort

目录 一.引言 二.排序简介 1.排序类型 2.时间复杂度 3.初级排序 4.高级排序 A.快速排序 B.归并排序 C.堆排序 5.特殊排序 三.经典算法实战 1.Quick-Sort 2.Merge-Sort 3.Heap-Sort 4.Relative-Sort-Array [1122] 5.Valid-anagram [242] 6.Merge-Intervals […

Java NIO (二)NIO Buffer类的重要方法(备份)

1 allocate()方法 在使用Buffer实例前,我们需要先获取Buffer子类的实例对象,并且分配内存空间。需要获取一个Buffer实例对象时,并不是使用子类的构造器来创建,而是调用子类的allocate()方法。 public class AllocateTest {static…

如何快速看懂一篇英文AI论文?

已经2024年了,该出现一个写论文解读AI Agent了。 大家肯定也在经常刷论文吧。 但真正尝试过用GPT去刷论文、写论文解读的小伙伴,一定深有体验——费劲。其他agents也没有能搞定的,今天我发现了一个超级厉害的写论文解读的agent &#xff0c…

某银行主机安全运营体系建设实践

随着商业银行业务的发展,主机规模持续增长,给安全团队运营工作带来极大挑战,传统的运营手段已经无法适应业务规模的快速发展,主要体现在主机资产数量多、类型复杂,安全团队难以对全量资产进行及时有效的梳理、管理&…

HCIA—— 16每日一讲:HTTP和HTTPS、无状态和cookie、持久连接和管线化、(初稿丢了,这是新稿,请宽恕我)

学习目标: HTTP和HTTPS、无状态和cookie、持久连接和管线化、HTTP的报文、URI和URL(初稿丢了,这是新稿,请宽恕我😶‍🌫️) 学习内容: HTTP无状态和cookieHTTPS持久连接和管线化 目…

vue2 pdfjs-2.8.335-dist pdf文件在线预览功能

1、首先先将 pdfjs-2.8.335-dist 文件夹从网上搜索下载,复制到public文件夹下. 2、在components下新建组件PdfViewer.vue文件 3、在el-upload 中调用 pdf-viewer 组件 4、在el-upload 中的 on-preview方法中加上对应的src路径 internalPreview(file) { //判断需要…

编译原理1.3习题 程序设计语言的发展历程

图源:文心一言 编译原理习题整理~🥝🥝 作为初学者的我,这些习题主要用于自我巩固。由于是自学,答案难免有误,非常欢迎各位小伙伴指正与讨论!👏💡 第1版:自…

IPv6隧道--GRE隧道

GRE隧道 通用路由封装协议GRE(Generic Routing Encapsulation)可以对某些网络层协议(如IPX、ATM、IPv6、AppleTalk等)的数据报文进行封装,使这些被封装的数据报文能够在另一个网络层协议(如IPv4)中传输。 GRE提供了将一种协议的报文封装在另一种协议报文中的机制,是一…

个人网站制作 Part 7 添加用户认证和数据库集成 | Web开发项目

文章目录 👩‍💻 基础Web开发练手项目系列:个人网站制作🚀 用户认证与数据库集成🔨添加用户认证🔧步骤 1: 使用Passport.js 🔨集成数据库🔧步骤 2: 使用MongoDB和Mongoose &#x1f…

Grafana(二)Grafana 两种数据源图表展示(json-api与数据库)

一. 背景介绍 在先前的博客文章中,我们搭建了Grafana ,它是一个开源的度量分析和可视化工具,可以通过将采集的数据分析、查询,然后进行可视化的展示,接下来我们重点介绍如何使用它来进行数据渲染图表展示 Docker安装G…

AIOps探索 | 基于大模型构建高效的运维知识及智能问答平台(2)

前面分享了平台对运维效率提升的重要性和挑战以及基于大模型的平台建设解决方案,新来的朋友点这里,一键回看精彩原文。 基于大模型构建高效的运维知识及智能问答平台(1)https://mp.csdn.net/mp_blog/creation/editor/135223109 …

【REMB 】翻译:草案remb-03

REMB REMB消息 以及 绝对时间戳选项 在带宽估计中的使用 :an absolute-value timestamp option for use in bandwidth estimatoin. 接收方带宽估计的RTCP消息 REMB 这位大神翻译的更好。 RTCP message for Receiver Estimated Maximum Bitrate draft-alvestrand-rmcat-remb-03…

iOS开发进阶(六):Xcode14 使用信号量造成线程优先级反转问题修复

文章目录 一、前言二、关于线程优先级反转三、优先级反转会造成什么后果四、怎么避免线程优先级反转五、使用信号量可能会造成线程优先级反转,且无法避免六、延伸阅读:iOS | Xcode中快速打开终端6.1 .sh绑定6.2 执行 pod install 脚本 七、延伸阅读&…

Android Activity的启动流程(Android-10)

前言 在Android开发中,我们经常会用到startActivity(Intent)方法,但是你知道startActivity(Intent)后Activity的启动流程吗?今天就专门讲一下最基础的startActivity(Intent)看一下Activity的启动流程,同时由于Launcher的启动后续…

STM32——DMA知识点及实战总结

1.DMA概念介绍 DMA,全称Direct Memory Access,即直接存储器访问。 DMA传输 将数据从一个地址空间复制到另一个地址空间。 注意:DMA传输无需CPU直接控制传输 2.DMA框图 3.DMA处理过程 外设的 8 个请求独立连接到每个通道,由 DMA_…

YOLOv5改进 | 融合改进篇 | 轻量化CCFM + SENetv2进行融合改进涨点 (全网独家首发)

一、本文介绍 本文给大家带来的改进机制是轻量化的Neck结构CCFM配合SENetv2改进的网络结构进行融合改进,其中CCFM为我本人根据RT-DETR模型一比一总结出来的,文中配其手撕结构图,其中SENetV2为网络结构重构化模块,通过其改进主干从而提取更有效的特征,这两个模块搭配在一起…

Java实现海南旅游景点推荐系统 JAVA+Vue+SpringBoot+MySQL

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 用户端2.2 管理员端 三、系统展示四、核心代码4.1 随机景点推荐4.2 景点评价4.3 协同推荐算法4.4 网站登录4.5 查询景点美食 五、免责说明 一、摘要 1.1 项目介绍 基于VueSpringBootMySQL的海南旅游推荐系统&#xff…

探索单元测试和 E2E 测试:提升软件质量的关键步骤(下)

🤍 前端开发工程师(主业)、技术博主(副业)、已过CET6 🍨 阿珊和她的猫_CSDN个人主页 🕠 牛客高级专题作者、在牛客打造高质量专栏《前端面试必备》 🍚 蓝桥云课签约作者、已在蓝桥云…