动手学深度学习(Pytorch版)代码实践 -卷积神经网络-23卷积神经网络LeNet

23卷积神经网络LeNet

在这里插入图片描述

import torch
from torch import nn
import liliPytorch as lp
import matplotlib.pyplot as plt# 定义一个卷积神经网络
net = nn.Sequential(nn.Conv2d(1, 6,  kernel_size=5, padding=2), # 卷积层1:输入通道数1,输出通道数6,卷积核大小5x5,填充2nn.ReLU(), # 激活函数nn.AvgPool2d(kernel_size=2, stride=2), # 平均池化层1:池化窗口大小2x2,步幅2nn.Conv2d(6, 16, kernel_size=5), # 卷积层2:输入通道数6,输出通道数16,卷积核大小5x5nn.ReLU(), nn.AvgPool2d(kernel_size=2, stride=2), # 平均池化层2:池化窗口大小2x2,步幅2nn.Flatten(), # 展平层:将多维输入展平为1维nn.Linear(16 * 5 * 5, 120), # 全连接层1:输入节点数16*5*5,输出节点数120nn.ReLU(),nn.Linear(120, 84), # 全连接层2:输入节点数120,输出节点数84nn.ReLU(), nn.Linear(84, 10) # 全连接层3:输入节点数84,输出节点数10(对应10个分类)
)# 通过在每一层打印输出的形状,我们可以检查模型
X = torch.rand(size=(1, 1, 28, 28), dtype=torch.float32) # 随机生成一个形状为(1,1,28,28)的张量,作为输入
for layer in net:X = layer(X) # 将输入依次通过每一层print(layer.__class__.__name__, 'output shape: \t', X.shape) # 打印每一层的输出形状
"""
Conv2d output shape:     torch.Size([1, 6, 28, 28])
ReLU output shape:       torch.Size([1, 6, 28, 28])
AvgPool2d output shape:          torch.Size([1, 6, 14, 14])
Conv2d output shape:     torch.Size([1, 16, 10, 10])
ReLU output shape:       torch.Size([1, 16, 10, 10])
AvgPool2d output shape:          torch.Size([1, 16, 5, 5])
Flatten output shape:    torch.Size([1, 400])
Linear output shape:     torch.Size([1, 120])
ReLU output shape:       torch.Size([1, 120])
Linear output shape:     torch.Size([1, 84])
ReLU output shape:       torch.Size([1, 84])
Linear output shape:     torch.Size([1, 10])
"""
# 模型训练
batch_size = 256
train_iter, test_iter = lp.loda_data_fashion_mnist(batch_size) # 加载Fashion-MNIST数据集#分类精度
def accuracy(y_hat,y): #@save"""计算预测正确的数量"""#判断y_hat.shape是否为二维以上的矩阵#并且列数大于1if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:#axis = 1 表示按照每一行#argmax(axis = 1)得到每行最大值的下标y_hat = y_hat.argmax(axis = 1)cmp = y_hat.type(y.dtype) == yreturn float(cmp.type(y.dtype).sum())def evaluate_accuracy_gpu(net, data_iter, device=None):"""使用GPU计算模型在数据集上的精度"""if isinstance(net, nn.Module):net.eval() # 将模型设置为评估模式metric = lp.Accumulator(2) # 正确预测数、预测总数with torch.no_grad(): # 禁用梯度计算for X, y in data_iter:if isinstance(X, list):X = [x.to(device) for x in X]else:X = X.to(device)y = y.to(device)metric.add(accuracy(net(X), y), y.numel()) # 累加正确预测数和样本总数return metric[0] / metric[1] # 返回精度def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):"""用GPU训练模型"""def init_weights(m):if type(m) == nn.Linear or type(m) == nn.Conv2d:nn.init.xavier_uniform_(m.weight) # 初始化权重net.apply(init_weights) # 对网络应用权重初始化print('training on', device)net.to(device) # 将模型加载到设备上optimizer = torch.optim.SGD(net.parameters(), lr=lr) # 使用随机梯度下降优化器loss = nn.CrossEntropyLoss() # 定义交叉熵损失函数animator = lp.Animator(xlabel='epoch', xlim=[1, num_epochs],legend=['train loss', 'train acc', 'test acc']) # 动画工具,绘制训练曲线timer, num_batches = lp.Timer(), len(train_iter) # 计时器和批次数for epoch in range(num_epochs):metric = lp.Accumulator(3) # 训练损失之和,训练准确率之和,样本数net.train() # 训练模式for i, (X, y) in enumerate(train_iter):timer.start()optimizer.zero_grad() # 梯度清零X, y = X.to(device), y.to(device) # 将数据加载到设备上y_hat = net(X) # 前向传播l = loss(y_hat, y) # 计算损失l.backward() # 反向传播optimizer.step() # 更新参数with torch.no_grad(): # 禁用梯度计算metric.add(l * X.shape[0], lp.accuracy(y_hat, y), X.shape[0]) # 累加损失、准确率和样本数timer.stop()train_l = metric[0] / metric[2] # 计算平均训练损失train_acc = metric[1] / metric[2] # 计算平均训练准确率if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:animator.add(epoch + (i + 1) / num_batches,(train_l, train_acc, None)) # 更新动画test_acc = evaluate_accuracy_gpu(net, test_iter, device) # 计算测试集上的准确率animator.add(epoch + 1, (None, None, test_acc)) # 更新动画print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, 'f'test acc {test_acc:.3f}')print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec 'f'on {str(device)}')lr, num_epochs = 0.5, 10
train_ch6(net, train_iter, test_iter, num_epochs, lr, lp.try_gpu()) # 训练模型
# d2l.plt.show() # 显示训练曲线
plt.show() # 显示训练曲线# lr = 0.9,Sigmoid()
# loss 0.466, train acc 0.825, test acc 0.808# lr = 0.1,Sigmoid()
# loss 1.277, train acc 0.551, test acc 0.568# lr = 0.1,ReLU()
# loss 0.339, train acc 0.874, test acc 0.803# lr = 0.5,ReLU()
# loss 0.302, train acc 0.887, test acc 0.857# lr = 0.6,ReLU()
# loss 0.316, train acc 0.878, test acc 0.861

运行结果:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/32476.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

《分析模式》“鸦脚”表示法起源,Everest、Barker和Hay

DDD领域驱动设计批评文集 做强化自测题获得“软件方法建模师”称号 《软件方法》各章合集 《分析模式》这本书里面用的并不是UML表示法。作者Martin Fowler在书中也说了,该书写于1994-1995年,当时还没有UML。作者在书中用的是一种常被人称为“鸦脚”的…

独立看门狗窗口开门狗

独立看门狗 接线图:按键用于阻塞喂狗。独立看门狗&窗口开门狗接线一样。 第一步,是开启时钟了,只有这个LSI时钟开启了独立看门狗才能运行,所以初始化独立看门狗之前,LSI必须得开启,但是这个开启LSI的…

Activity生命周期:深入解析与面试准备

在Android开发中,Activity的生命周期是一个至关重要的概念。它不仅关系到应用的性能和用户体验,也是面试中常被提及的技术点。以下将从技术难点、面试官关注点、回答吸引力以及代码举例四个方面,详细阐述Activity的生命周期及其各个回调方法的…

随记:内卷是什么意思?

内卷,网络流行语,原指一类文化模式达到了某种最终的形态以后,既没有办法稳定下来,也没有办法转变为新的形态,而只能不断地在内部变得更加复杂的现象。经网络流传,很多高等学校学生用其来指代非理性的内部竞…

视频格式怎么转换?9 个免费视频转换工具

前 9 款免费视频转换器有哪些?在此视频转换器评论中,我们收集了一些有用的提示并列出了顶级免费视频转换器软件,还找出了适合所有级别(从初学者到专家)的最佳免费视频转换器。 1. Geekersoft免费在线视频转换 最好的免…

kafka(二)安装部署(2)windows

一、前提 安装Kafka之前,需要安装JDK、Zookeeper、Scala, 本次安装版本选择: JDK:1.8 Zookeeper:3.6.4 Scala:2.12 Kafka:3.5.2 1、jdk Java Downloads | Oracle 见jdk下载安装。 2、Zookeeper 下载…

C# Winform中制作精美控件(2)

仓库温度监控系统重有个控件,就是温度监控,还是比较精美的,那么我们来看看制作的要点有哪些。 前面我们讨论过布局和圆角按钮。这节主要关注温度计控件 1. 布局: 两个Panel将界面分位上下两个部分,Dock.Top Dock.Fil…

关于小程序内嵌H5页面交互的问题?

有木有遇到?有木有遇到。 小程序内嵌了H5,然后H5某个按钮,需要打开小程序某个页面进行信息完善或登记,登记后要返回H5页面,而H5页面要动态显示刚才在小程序页面登记的信息。 操作流程是这样: 方案1&#…

编译原理期末复习

BUCT往年试题为导向的复习 标*的为往年真题 目录 1.基本概念 *例题(编译主要阶段) 编译程序与解释性程序区别 LL(1)概念 2.正则表达式转DFA (1)正则表达式转NFA 第一种方法(编程时常用) 第二种(考试时常用) &#xff08…

MK的前端精华笔记

文章目录 MK的前端精华笔记第一阶段:前端基础入门1、(1)、(2)、 2、3、4、5、6、7、 第二阶段:组件化与移动WebAPP开发1、(1)、(2)、 2、3、4、5、6、7、 第三…

【JavaEE】Spring Web MVC详解

一.基本概念. 1.什么是Spring Web MVC? 官方链接: https://docs.spring.io/spring-framework/reference/web/webmvc.html Spring Web MVC is the original web framework built on the Servlet API and has been included in the Spring Framework from the very beginning…

【linux】centos yum 换源

一、问题描述 CentOS 7 更换yum源后,无法正常使用,报错信息如下: [roothost-10-43-1-3 ~]# yum install tmux Loaded plugins: fastestmirror, langpacks Repository base is listed more than once in the configuration Repository updat…

代码随想录算法训练营day23|669.修剪二叉搜索树、108.将有序数组转换为二叉搜索树、538.把二叉搜索树转换为累加树

669.修剪二叉搜索树 这道题目需要考虑当前节点是否在[low,high]之间, 因为是平衡二叉树, 所以当当前节点值小于low时,那么其左节点肯定更小,因此删除该节点的方式是给root节点返回其右节点的递归,注意:这里…

爬虫笔记13——网页爬取数据写入MySQL数据库,以阿里recruit为例

下载pymysql库 需要下载pymysql库,以便于在程序中连接MySQL数据库 pip install pymysql # 或者使用国内的镜像文件,安装会快一点 pip install pymysql -i https://pypi.douban.com/simple需要安装MySQL,并创建使用数据库 安装MySQL可以看这…

【ajax基础】回调函数地狱

一:什么是回调函数地狱 在一个回调函数中嵌套另一个回调函数(甚至一直嵌套下去),形成回调函数地狱 回调函数地狱存在问题: 可读性差异常捕获严重耦合性严重 // 1. 获取默认第一个省份的名字axios({url: http://hmaj…

【流星蝴蝶剑game】

由于《流星蝴蝶剑》是一款较旧的游戏,而且我无法提供受版权保护的游戏的代码,我将提供一个简单的2D游戏编程实例,以展示如何使用Unity引擎和C#语言来创建一个基本的游戏。这个例子将涉及到创建一个玩家角色,使其能够移动并收集物品…

5.什么是C语言

什么是 C 语言? C语言是一种用于和计算机交流的高级语言, 它既具有高级语言的特点,又具有汇编语言的特点 非常接近自然语言程序的执行效率非常高 C语言是所有编程语言中的经典,很多高级语言都是从C语言中衍生出来的, 例如:C、C#、Object-C、…

Android招聘市场技术要求越来越高,从事三年开发是否应该考虑转行?

UI这块知识是现今使用者最多的。当年火爆一时的Android入门培训,学会这小块知识就能随便找到不错的工作了。 不过很显然现在远远不够了,拒绝无休止的CV,亲自去项目实战,读源码,研究原理吧! 《Framework精编…

Unity 字体创建时候容易导致字体文件不正确的一种情况

上面得到了两种字体格式,一种是TextMeshPro的,另一种是Unity UI系统中默认使用的字体资源。其原因是创建的位置不同导致的。 1.下面是TextMeshPro字体创建的位置 2:下面是Unity UI系统中默认使用的字体资源

【FreeRTOS】任务状态改进播放控制

这里写目录标题 1 任务状态1.1 阻塞状态(Blocked)1.2 暂停状态(Suspended)1.3 就绪状态(Ready)1.4 完整的状态转换图 2 举个例子3 编写代码 参考《FreeRTOS入门与工程实践(基于DshanMCU-103).pdf》 本节课实现音乐任务的创建,音乐播放的暂停与继续播放,删…