LeNet-5(fashion-mnist)

文章目录

  • 前言
  • LeNet
  • 模型训练

前言

LeNet是最早发布的卷积神经网络之一。该模型被提出用于识别图像中的手写数字。

LeNet

LeNet-5由以下两个部分组成

  • 卷积编码器(2)
  • 全连接层(3)
    卷积块由一个卷积层、一个sigmoid激活函数和一个平均汇聚层组成。
    第一个卷积层有6个输出通道,第二个卷积层有16个输出通道。采用2×2的汇聚操作,且步幅为2.
    3个全连接层分别有120,84,10个输出。
    此处对原始模型做出部分修改,去除最后一层的高斯激活。
net=nn.Sequential(nn.Conv2d(1,6,kernel_size=5,padding=2),nn.Sigmoid(),nn.AvgPool2d(kernel_size=2,stride=2),nn.Conv2d(6,16,kernel_size=5),nn.Sigmoid(),nn.AvgPool2d(kernel_size=2,stride=2),nn.Flatten(),nn.Linear(16*5*5,120),nn.Sigmoid(),nn.Linear(120,84),nn.Sigmoid(),nn.Linear(84,10))

模型训练

为了加快训练,使用GPU计算测试集上的精度以及训练过程中的计算。
此处采用xavier初始化模型参数以及交叉熵损失函数和小批量梯度下降。

batch_size=256
train_iter,test_iter=data_iter.load_data_fashion_mnist(batch_size)

将数据送入GPU进行计算测试集准确率

def evaluate_accuracy_gpu(net,data_iter,device=None):"""使用GPU计算模型在数据集上的精度"""if isinstance(net,torch.nn.Module):net.eval()if not device:device=next(iter(net.parameters())).device# 正确预测的数量,预测的总数eva = 0.0y_num = 0.0with torch.no_grad():for X,y in data_iter:if isinstance(X,list):X=[x.to(device) for x in X]else:X=X.to(device)y=y.to(device)eva += accuracy(net(X), y)y_num += y.numel()return eva/y_num

训练过程同样将数据送入GPU计算

def train_epoch_gpu(net, train_iter, loss, updater,device):# 训练损失之和,训练准确数之和,样本数train_loss_sum = 0.0train_acc_sum = 0.0num_samples = 0.0# timer = d2l.torch.Timer()for i, (X, y) in enumerate(train_iter):# timer.start()updater.zero_grad()X, y = X.to(device), y.to(device)y_hat = net(X)l = loss(y_hat, y)l.backward()updater.step()with torch.no_grad():train_loss_sum += l * X.shape[0]train_acc_sum += evaluation.accuracy(y_hat, y)num_samples += X.shape[0]# timer.stop()return train_loss_sum/num_samples,train_acc_sum/num_samplesdef train_gpu(net,train_iter,test_iter,num_epochs,lr,device):def init_weights(m):if type(m)==torch.nn.Linear or type(m)==torch.nn.Conv2d:torch.nn.init.xavier_uniform_(m.weight)net.apply(init_weights)net.to(device)print('training on',device)optimizer=torch.optim.SGD(net.parameters(),lr=lr)loss=torch.nn.CrossEntropyLoss()# num_batches=len(train_iter)tr_l=[]tr_a=[]te_a=[]for epoch in range(num_epochs):net.train()train_metric=train_epoch_gpu(net,train_iter,loss,optimizer,device)test_accuracy = evaluation.evaluate_accuracy_gpu(net, test_iter)train_loss, train_acc = train_metrictrain_loss = train_loss.cpu().detach().numpy()tr_l.append(train_loss)tr_a.append(train_acc)te_a.append(test_accuracy)print(f'epoch: {epoch + 1}, train_loss: {train_loss}, train_acc: {train_acc}, test_acc:{test_accuracy}')x = torch.arange(num_epochs)plt.plot((x + 1), tr_l, '-', label='train_loss')plt.plot(x + 1, tr_a, '--', label='train_acc')plt.plot(x + 1, te_a, '-.', label='test_acc')plt.legend()plt.show()print(f'on {str(device)}')
lr,num_epochs=0.9,10
Train.train_gpu(net,train_iter,test_iter,num_epochs,lr,device='cuda')

在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/612384.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

了解结构体以及结构体数组

C语言的结构体你真的了解吗? 一起来看一下吧!!! 1.结构体是啥? 结构体是多种数据类型的组合体 2.格式(一般放在主函数前,也就是int main()前面 ) 关键字 结构体名字 {成员列表…

python代码练习:双指针法

题目一:移除元素 给你一个数组 nums 和一个值 val,你需要 原地 移除所有数值等于 val 的元素,并返回移除后数组的新长度。 不要使用额外的数组空间,你必须仅使用 O(1) 额外空间并 原地 修改输入数组。 元素的顺序可以改变。你不…

「HDLBits题解」Module

本专栏的目的是分享可以通过HDLBits仿真的Verilog代码 以提供参考 各位可同时参考我的代码和官方题解代码 或许会有所收益 题目链接:Module - HDLBits module top_module ( input a, input b, output out );mod_a u1 (.in1(a), .in2(b), .out(out)) ; endmodule

[足式机器人]Part3 机构运动学与动力学分析与建模 Ch00-2(1) 质量刚体的在坐标系下运动

本文仅供学习使用,总结很多本现有讲述运动学或动力学书籍后的总结,从矢量的角度进行分析,方法比较传统,但更易理解,并且现有的看似抽象方法,两者本质上并无不同。 2024年底本人学位论文发表后方可摘抄 若有…

【论文阅读】Deep Graph Infomax

目录 0、基本信息1、研究动机2、创新点2.1、核心思想:2.2、思想推导: 3、准备3.1、符号3.2、互信息3.3、JS散度3.4、Deep InfoMax方法3.5、判别器:f-GAN估计散度 4、具体实现4.1、局部-全局互信息最大化4.2、理论动机 5、实验设置5.1、直推式…

[Linux]网卡配置修改

CentOS修改网卡配置 Ubuntu 修改配置文件方式 使用vim命令,打开/etc/resolv.conf文件,按i开始编辑 #DNS #代表注释 nameserver 127.0.0.53ESC之后:wq保存退出 重启网络服务 service networking restartnetplan方式 Ubuntu 18.04开始,可…

C# 使用Fleck创建WebSocket服务器

目录 写在前面 代码实现 服务端代码 客户端代码 调用示例 写在前面 Fleck 是 C# 实现的 WebSocket 服务器,通过 WebSocket API,浏览器和服务器只需要做一个握手的动作,然后浏览器和服务器之间就形成了一条快速通道;两者之间…

1.5 Unity中的数据存储 PlayerPrefs

Unity中的三种数据存储:数据存储也称为数据持久化 一、PlayerPrefs PlayerPrefs是Unity引擎自身提供的一个用于本地持久化保存与读取的类,以键值对的形式将数据保存在文件中,然后程序可以根据关键字提取数值。 PlayerPrefs类支持3种数据类…

文章解读与仿真程序复现思路——电网技术EI\CSCD\北大核心《计及储能参与的电能-调频-备用市场日前联合交易决策模型》

本专栏栏目提供文章与程序复现思路,具体已有的论文与论文源程序可翻阅本博主免费的专栏栏目《论文与完整程序》 这个标题涉及到电能、调频和备用市场的联合交易决策模型,并特别考虑了储能在其中的参与。 电能市场: 这是指电能的买卖市场&…

What does `rpm -e` do?

卸载 rpm包 rpm -e php 卸载匹配的所有rpm包rpm -e $(rpm -qa php*) 卸载匹配的所有rpm包[Ref] Erase multiple packages using rpm or yum Further Reading :Linux rpm 命令

Java使用IText生产PDF时,中文标点符号出现在行首的问题处理

Java使用IText生成PDF时,中文标点符号出现在行首的问题处理 使用itext 5进行html转成pdf时,标点符号出现在某一行的开头 但这种情况下显然不符合中文书写的规则,主要问题出在itext中的DefaultSplitCharacter类,该方法主要用来判断…

SpringMVC-02

添加EnableWebMvc //配置json转化器 (使用postman) 可以不用写下面两个方法了 Bean public RequestMappingHandlerMapping handlerMapping(){ return new RequestMappingHandlerMapping(); } Bean public RequestMappingHandlerAdapter handlerAdapter()…

04- OpenCV:Mat对象简介和使用

目录 1、Mat对象与IplImage对象 2、Mat对象使用 3、Mat定义数组 4、相关的代码演示 1、Mat对象与IplImage对象 先看看Mat对象:图片在计算机眼里都是一个二维数组; 在OpenCV中,Mat是一个非常重要的类,用于表示图像或矩阵数据。…

⭐Unity 将电脑打开的窗口画面显示在程序中

1.效果: 下载资源包地址: Unity中获取桌面窗口 2.下载uWindowCapturev1.1.2.unitypackage 放入Unity工程 3.打开Single Window场景,将组件UwcWindowTexture的PartialWindowTitle进行修改,我以腾讯会议为例 感谢大家的观看&#xf…

CSS3实现轮播效果

在我们不使用JS的情况下&#xff0c;是否也可以实现轮播功能呢&#xff1f; 答应是可以的 上代码&#xff1a; <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><title>轮播</title><style>.boss…

激活函数整理

sigmoid函数 import torch from d2l import torch as d2l %matplotlib inline ​ xtorch.arange(-10,10,0.1,requires_gradTrue) sigmoidtorch.nn.Sigmoid() ysigmoid(x) ​ d2l.plot(x.detach(),y.detach(),x,sigmoid(x),figsize(5,2.5)) sigmoid函数连续、光滑、单调递增&am…

Java language programming:判断整数n是否在数组中存在

题目&#xff1a;已知给定一个整数数组&#xff0c;输入一个整数n&#xff0c;那么如果该整数n存在于这个数组中&#xff0c;则需要输出下标&#xff1b;如果不存在的话&#xff0c;则需要输出-1。 输入格式&#xff1a; 1 9 输出格式&#xff1a; 0 -1 import java.util.*; …

python爬取诗词名句网-三国演义,涉及知识点:xpath,requests,自动识别编码,range

页面源代码: <!DOCTYPE html> <html lang="zh"> <head><script src="https://img.shicimingju.com/newpage/js/all.js"></script><meta charset="UTF-8"><title>《三国演义》全集在线阅读_史书典籍_…

mysql 分组函数,分组查询

#1.分组函数 功能&#xff1a;用作统计使用&#xff0c;又称聚合函数&#xff0c;统计函数&#xff0c;组函数 分类: sum :求和&#xff0c;avg 平均值&#xff0c;max最大值&#xff0c;min最小值&#xff0c;count计算个数 特点: sum, avg 一般用于处理数值型 max ,min ,coun…

metartc5_jz源码阅读-udp->receive

之前在metartc5_jz源码阅读-yang_run_rtcudp_thread-CSDN博客中说到&#xff1a; //调用udp的receive方法将读取的buffer和udp->user传入。 if (udp->receive) udp->receive(buffer, len, udp->user); 这个函数在以下代码中已经设置执行函数&#xff1a; sessio…