23/76-LeNet

LeNet
早期成功的神经网络。
先使用卷积层来学习图片空间信息。
然后使用全连接层转换到类别空间。

在这里插入图片描述

#In[]
'''
LeNet,上世纪80年代的产物,最初为了手写识别设计
'''
from d2l import torch as d2l
import torch 
from torch import nn
from torch.nn.modules.loss import CrossEntropyLossfrom torch.utils import data
import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt
import Common_functions'''
LeNet:
两个卷积层,两个池化层,三个线性层
假定为MNIST设计,输入为(batch_size,1,28,28)
'''class Reshape(torch.nn.Module):def forward(self,x):return x.view(-1,1,28,28)net = nn.Sequential(nn.Conv2d(in_channels=1,out_channels=6,kernel_size=(5,5),padding=2),nn.Sigmoid(), #输出:(6,28,28)nn.AvgPool2d(kernel_size=(2,2)), #不指定stride默认不重叠 输出(6,14,14)nn.Conv2d(6,16,kernel_size=(5,5)),nn.Sigmoid(),#输出(16,10,10)nn.AvgPool2d(kernel_size=(2,2)),#输出(16,5,5)nn.Flatten(),nn.Linear(16*5*5,120),nn.Sigmoid(),#nn.Linear(120,84),nn.Sigmoid(),nn.Linear(84,10)
)X=torch.rand(size=(1,1,28,28),dtype=torch.float32)
for layer in net:X=layer(X)print(layer.__class__.__name__,'output shape: \t',X.shape)#In[]batch_size = 256
train_iter,test_iter=d2l.load_data_fashion_mnist(batch_size=batch_size)#对evaluate_accuracy函数进行轻微修改
#使用GPU计算模型在数据集上的精度
#计算网络在测试数据集上面的准确率
#由于完整的测试数据集位于内存中,因此在模型使用GPU预测测试数据集之前,我们需要将其复制到显存中。
def evaluate_accuracy_gpu(net,data_iter,device=None):if isinstance(net,nn.Module):net.eval() #网络用于测试数据if not device:device = next(iter(net.parameters())).device #如果没有指定device设备,device设备则使用第一层网络参数的设备accumulator = d2l.Accumulator(2) #累加器里面包含两个元素for X,y in data_iter:if isinstance(X,list):X = [x.to(device) for x in X] #X为list类型时,需要加X里面每个元素都复制到device设备上面来else:X = X.to(device)y = y.to(device)accumulator.add(d2l.accuracy(net(X),y),y.numel()) #累加器第一个元素为在每一个batch_size中预测准确的个数,第二个元素为每一个batch_size中样本总数目,然后依次循环累加,得到测试数据集上面预测准确的总数目,以及数据集总数目return accumulator[0]/accumulator[1] #算出模型预测准确率def train_ch6(net,train_iter,test_iter,num_epochs,lr,device):def init_weights(m):#手动初始化模型参数if type(m) == nn.Linear or type(m) == nn.Conv2d:nn.init.xavier_uniform_(m.weight) #使用xavier_uniform分布初始化参数net.apply(init_weights)net.to(device)#将模型复制到gpu上面print('training on',device)loss = nn.CrossEntropyLoss() #定义lossoptim = torch.optim.SGD(net.parameters(),lr=lr) #定义优化器animator = d2l.Animator(xlabel='epoch',xlim=[1,num_epochs],legend=['train_loss','train_acc','test_acc'])timer = d2l.Timer()num_batches = len(train_iter)for epoch in range(num_epochs):net.train()#模型开始训练,需要放在第一层循环里面,因为后面evaluate_accuracy_gpu()函数里面有net.eval(),将模型改变为测试状态,因此需要在每一个循环epoch后面手动再加上模型开始处于训练状态accumulator = d2l.Accumulator(3) #累加器for i,(X,y) in enumerate(train_iter):timer.start()optim.zero_grad()X = X.to(device)#将X复制到gpu上面y = y.to(device) #将y复制到gpu上面y_hat = net(X) #得到模型训练后的输出标签y_hatl = loss(y_hat,y)#计算每一个batch_size的lossl.backward() #计算梯度optim.step() #使用优化器更新模型参数with torch.no_grad():#不需要模型梯度accumulator.add(l*X.shape[0],d2l.accuracy(y_hat,y),X.shape[0])timer.stop()train_loss = accumulator[0]/accumulator[2] #从累加器里面获得所有训练集的loss之和train_acc = accumulator[1]/accumulator[2] #从累加器里面获得所有训练集的准确数之和if (i+1) % (num_batches // 5) == 0 or i == num_batches-1:animator.add(epoch+(i+1)/num_batches,(train_loss,train_acc,None))test_accuracy = evaluate_accuracy_gpu(net,test_iter) #每次训练完一个epoch后的模型用于测试数据集上面计算测试精确度animator.add(epoch+1,(None,None,test_accuracy))print(f'模型训练完最后一轮时 train_loss:{train_loss},train_acc:{train_acc},test_acc:{test_accuracy}')print(f'{num_epochs*accumulator[2]/timer.sum()}examples/second on {str(device)}')#打印出模型每秒能处理多少个样本数lr,num_epochs= 0.9,10
train_ch6(net,train_iter=train_iter,test_iter=test_iter,lr=lr,num_epochs=num_epochs,device=d2l.try_gpu())
'''
输出结果:
模型训练完最后一轮时 train_loss:0.4322478462855021,train_acc:0.8396666666666667,test_acc:0.8163
55954.65804440994examples/second on cuda:0
'''#训练
if torch.cuda.is_available():device = "cuda:0"
else:device = "cpu"
device = torch.device(device)Common_functions.train_device(net,train_iter,test_iter,lr=0.9,device=device)
# %%plt.show()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/629009.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

无法加载操作系统,原因是关键系统驱动程序丢失或包含错误

bcdboot c:\windows /l zh-cn 用这个命令解决了,没有进入时候蓝屏了,不知道为什么 问题 无法加载操作系统,原因是关键系统驱动程序丢失或包含错误上午因为有点事就没有像往常一样打开电脑,下午回到家休息了一会本来准备打开电脑开始我愉快地下午生活,没想到一个自动恢复给…

工业平板定制方案_基于联发科、紫光展锐平台的工业平板电脑方案

工业平板主板采用联发科MT6762平台方案,搭载Android 11.0操作系统, 主频最高2.0GHz,效能有大幅提升;采用12nm先进工艺,具有低功耗高性能的特点。 该工业平板主板搭载了IMG GE8320图形处理器,最高主频为680MHz, 支持108…

Java设计模式之访问者模式详解

Java设计模式之访问者模式详解 大家好,我是免费搭建查券返利机器人赚佣金就用微赚淘客系统3.0的小编,也是冬天不穿秋裤,天冷也要风度的程序猿!今天,让我们一同踏上Java设计模式之旅,探索访问者模式&#x…

Flume 之自定义Sink

1、简介 前文我们介绍了 Flume 如何自定义 Source, 并进行案例演示,本文将接着前文,自定义Sink,在这篇文章中,将使用自定义 Source 和 自定义的 Sink 实现数据传输,让大家快速掌握Flume这门技术。 2、自定…

JVM与HotSpot

JVM和HotSpot 1、概念 JVM是虚拟机的规范,HotSpot是jvm的具体实现 HotSpot包括一个解释器和两个编译器(client 和 server,二选一的),解释与编译混合执行模式,默认启动解释执行。 编译器:java源…

121_买卖股票的最佳时机

描述 给定一个数组 prices ,它的第 i 个元素 prices[i] 表示一支给定股票第 i 天的价格。 你只能选择某一天买入这只股票,并选择在未来的某一个不同的日子卖出该股票。设计一个算法来计算你所能获取的最大利润。 返回你可以从这笔交易中获取的最大利润。…

Python - 深夜数据结构与算法之 Sort

目录 一.引言 二.排序简介 1.排序类型 2.时间复杂度 3.初级排序 4.高级排序 A.快速排序 B.归并排序 C.堆排序 5.特殊排序 三.经典算法实战 1.Quick-Sort 2.Merge-Sort 3.Heap-Sort 4.Relative-Sort-Array [1122] 5.Valid-anagram [242] 6.Merge-Intervals […

Java NIO (二)NIO Buffer类的重要方法(备份)

1 allocate()方法 在使用Buffer实例前,我们需要先获取Buffer子类的实例对象,并且分配内存空间。需要获取一个Buffer实例对象时,并不是使用子类的构造器来创建,而是调用子类的allocate()方法。 public class AllocateTest {static…

如何快速看懂一篇英文AI论文?

已经2024年了,该出现一个写论文解读AI Agent了。 大家肯定也在经常刷论文吧。 但真正尝试过用GPT去刷论文、写论文解读的小伙伴,一定深有体验——费劲。其他agents也没有能搞定的,今天我发现了一个超级厉害的写论文解读的agent &#xff0c…

某银行主机安全运营体系建设实践

随着商业银行业务的发展,主机规模持续增长,给安全团队运营工作带来极大挑战,传统的运营手段已经无法适应业务规模的快速发展,主要体现在主机资产数量多、类型复杂,安全团队难以对全量资产进行及时有效的梳理、管理&…

JS中数组的相关方法介绍

push() 将一个或多个元素添加到数组的末尾,并返回新的长度。 let arr [1, 2, 3]; arr.push(4); // arr 现在是 [1, 2, 3, 4] pop() 删除并返回数组的最后一个元素 let arr [1, 2, 3, 4]; let last arr.pop(); // last 现在是 4,arr 现在是 [1, …

第23章 集 ,势(阿列夫0),良序集(序数),有理数无理数

继续讲解集,接下来讲集的运算,集合的交和并,上开口是交集下开口是并集,这里有一些类似于加法和乘法的样子,其实也没有错,乘法符号也只是一个符号,真正有用的是表示的交换和结合率 集这个概念&a…

HCIA—— 16每日一讲:HTTP和HTTPS、无状态和cookie、持久连接和管线化、(初稿丢了,这是新稿,请宽恕我)

学习目标: HTTP和HTTPS、无状态和cookie、持久连接和管线化、HTTP的报文、URI和URL(初稿丢了,这是新稿,请宽恕我😶‍🌫️) 学习内容: HTTP无状态和cookieHTTPS持久连接和管线化 目…

深入解析iOS中的layoutSubviews方法

深入解析iOS中的layoutSubviews方法 大家好,我是免费搭建查券返利机器人赚佣金就用微赚淘客系统3.0的小编,也是冬天不穿秋裤,天冷也要风度的程序猿!在今天的文章中,我们将深入研究iOS开发中一个不可或缺的方法——lay…

vue2 pdfjs-2.8.335-dist pdf文件在线预览功能

1、首先先将 pdfjs-2.8.335-dist 文件夹从网上搜索下载,复制到public文件夹下. 2、在components下新建组件PdfViewer.vue文件 3、在el-upload 中调用 pdf-viewer 组件 4、在el-upload 中的 on-preview方法中加上对应的src路径 internalPreview(file) { //判断需要…

编译原理1.3习题 程序设计语言的发展历程

图源:文心一言 编译原理习题整理~🥝🥝 作为初学者的我,这些习题主要用于自我巩固。由于是自学,答案难免有误,非常欢迎各位小伙伴指正与讨论!👏💡 第1版:自…

go语言GMP模式介绍以及协程案例展示

一. MPG模式 Go语言的调度模型被称为GMP,这是一个高效且复杂的调度系统,用于在可用的物理线程上调度goroutines(Go的轻量级线程)。GMP模型由三个主要组件构成:Goroutine、M(机器)和P&#xff0…

IPv6隧道--GRE隧道

GRE隧道 通用路由封装协议GRE(Generic Routing Encapsulation)可以对某些网络层协议(如IPX、ATM、IPv6、AppleTalk等)的数据报文进行封装,使这些被封装的数据报文能够在另一个网络层协议(如IPv4)中传输。 GRE提供了将一种协议的报文封装在另一种协议报文中的机制,是一…

Linux一条命令换阿里源

要在Linux系统中切换到阿里源,可以使用以下命令。请注意,不同的Linux发行版可能有不同的包管理工具,因此命令可能会有所不同。 对于使用apt的Debian/Ubuntu系统: sudo cp /etc/apt/sources.list /etc/apt/sources.list.backup …

个人网站制作 Part 7 添加用户认证和数据库集成 | Web开发项目

文章目录 👩‍💻 基础Web开发练手项目系列:个人网站制作🚀 用户认证与数据库集成🔨添加用户认证🔧步骤 1: 使用Passport.js 🔨集成数据库🔧步骤 2: 使用MongoDB和Mongoose &#x1f…