动手学深度学习(Pytorch版)代码实践 -卷积神经网络-23卷积神经网络LeNet

23卷积神经网络LeNet

在这里插入图片描述

import torch
from torch import nn
import liliPytorch as lp
import matplotlib.pyplot as plt# 定义一个卷积神经网络
net = nn.Sequential(nn.Conv2d(1, 6,  kernel_size=5, padding=2), # 卷积层1:输入通道数1,输出通道数6,卷积核大小5x5,填充2nn.ReLU(), # 激活函数nn.AvgPool2d(kernel_size=2, stride=2), # 平均池化层1:池化窗口大小2x2,步幅2nn.Conv2d(6, 16, kernel_size=5), # 卷积层2:输入通道数6,输出通道数16,卷积核大小5x5nn.ReLU(), nn.AvgPool2d(kernel_size=2, stride=2), # 平均池化层2:池化窗口大小2x2,步幅2nn.Flatten(), # 展平层:将多维输入展平为1维nn.Linear(16 * 5 * 5, 120), # 全连接层1:输入节点数16*5*5,输出节点数120nn.ReLU(),nn.Linear(120, 84), # 全连接层2:输入节点数120,输出节点数84nn.ReLU(), nn.Linear(84, 10) # 全连接层3:输入节点数84,输出节点数10(对应10个分类)
)# 通过在每一层打印输出的形状,我们可以检查模型
X = torch.rand(size=(1, 1, 28, 28), dtype=torch.float32) # 随机生成一个形状为(1,1,28,28)的张量,作为输入
for layer in net:X = layer(X) # 将输入依次通过每一层print(layer.__class__.__name__, 'output shape: \t', X.shape) # 打印每一层的输出形状
"""
Conv2d output shape:     torch.Size([1, 6, 28, 28])
ReLU output shape:       torch.Size([1, 6, 28, 28])
AvgPool2d output shape:          torch.Size([1, 6, 14, 14])
Conv2d output shape:     torch.Size([1, 16, 10, 10])
ReLU output shape:       torch.Size([1, 16, 10, 10])
AvgPool2d output shape:          torch.Size([1, 16, 5, 5])
Flatten output shape:    torch.Size([1, 400])
Linear output shape:     torch.Size([1, 120])
ReLU output shape:       torch.Size([1, 120])
Linear output shape:     torch.Size([1, 84])
ReLU output shape:       torch.Size([1, 84])
Linear output shape:     torch.Size([1, 10])
"""
# 模型训练
batch_size = 256
train_iter, test_iter = lp.loda_data_fashion_mnist(batch_size) # 加载Fashion-MNIST数据集#分类精度
def accuracy(y_hat,y): #@save"""计算预测正确的数量"""#判断y_hat.shape是否为二维以上的矩阵#并且列数大于1if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:#axis = 1 表示按照每一行#argmax(axis = 1)得到每行最大值的下标y_hat = y_hat.argmax(axis = 1)cmp = y_hat.type(y.dtype) == yreturn float(cmp.type(y.dtype).sum())def evaluate_accuracy_gpu(net, data_iter, device=None):"""使用GPU计算模型在数据集上的精度"""if isinstance(net, nn.Module):net.eval() # 将模型设置为评估模式metric = lp.Accumulator(2) # 正确预测数、预测总数with torch.no_grad(): # 禁用梯度计算for X, y in data_iter:if isinstance(X, list):X = [x.to(device) for x in X]else:X = X.to(device)y = y.to(device)metric.add(accuracy(net(X), y), y.numel()) # 累加正确预测数和样本总数return metric[0] / metric[1] # 返回精度def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):"""用GPU训练模型"""def init_weights(m):if type(m) == nn.Linear or type(m) == nn.Conv2d:nn.init.xavier_uniform_(m.weight) # 初始化权重net.apply(init_weights) # 对网络应用权重初始化print('training on', device)net.to(device) # 将模型加载到设备上optimizer = torch.optim.SGD(net.parameters(), lr=lr) # 使用随机梯度下降优化器loss = nn.CrossEntropyLoss() # 定义交叉熵损失函数animator = lp.Animator(xlabel='epoch', xlim=[1, num_epochs],legend=['train loss', 'train acc', 'test acc']) # 动画工具,绘制训练曲线timer, num_batches = lp.Timer(), len(train_iter) # 计时器和批次数for epoch in range(num_epochs):metric = lp.Accumulator(3) # 训练损失之和,训练准确率之和,样本数net.train() # 训练模式for i, (X, y) in enumerate(train_iter):timer.start()optimizer.zero_grad() # 梯度清零X, y = X.to(device), y.to(device) # 将数据加载到设备上y_hat = net(X) # 前向传播l = loss(y_hat, y) # 计算损失l.backward() # 反向传播optimizer.step() # 更新参数with torch.no_grad(): # 禁用梯度计算metric.add(l * X.shape[0], lp.accuracy(y_hat, y), X.shape[0]) # 累加损失、准确率和样本数timer.stop()train_l = metric[0] / metric[2] # 计算平均训练损失train_acc = metric[1] / metric[2] # 计算平均训练准确率if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:animator.add(epoch + (i + 1) / num_batches,(train_l, train_acc, None)) # 更新动画test_acc = evaluate_accuracy_gpu(net, test_iter, device) # 计算测试集上的准确率animator.add(epoch + 1, (None, None, test_acc)) # 更新动画print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, 'f'test acc {test_acc:.3f}')print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec 'f'on {str(device)}')lr, num_epochs = 0.5, 10
train_ch6(net, train_iter, test_iter, num_epochs, lr, lp.try_gpu()) # 训练模型
# d2l.plt.show() # 显示训练曲线
plt.show() # 显示训练曲线# lr = 0.9,Sigmoid()
# loss 0.466, train acc 0.825, test acc 0.808# lr = 0.1,Sigmoid()
# loss 1.277, train acc 0.551, test acc 0.568# lr = 0.1,ReLU()
# loss 0.339, train acc 0.874, test acc 0.803# lr = 0.5,ReLU()
# loss 0.302, train acc 0.887, test acc 0.857# lr = 0.6,ReLU()
# loss 0.316, train acc 0.878, test acc 0.861

运行结果:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/32476.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

《分析模式》“鸦脚”表示法起源,Everest、Barker和Hay

DDD领域驱动设计批评文集 做强化自测题获得“软件方法建模师”称号 《软件方法》各章合集 《分析模式》这本书里面用的并不是UML表示法。作者Martin Fowler在书中也说了,该书写于1994-1995年,当时还没有UML。作者在书中用的是一种常被人称为“鸦脚”的…

独立看门狗窗口开门狗

独立看门狗 接线图:按键用于阻塞喂狗。独立看门狗&窗口开门狗接线一样。 第一步,是开启时钟了,只有这个LSI时钟开启了独立看门狗才能运行,所以初始化独立看门狗之前,LSI必须得开启,但是这个开启LSI的…

随记:内卷是什么意思?

内卷,网络流行语,原指一类文化模式达到了某种最终的形态以后,既没有办法稳定下来,也没有办法转变为新的形态,而只能不断地在内部变得更加复杂的现象。经网络流传,很多高等学校学生用其来指代非理性的内部竞…

视频格式怎么转换?9 个免费视频转换工具

前 9 款免费视频转换器有哪些?在此视频转换器评论中,我们收集了一些有用的提示并列出了顶级免费视频转换器软件,还找出了适合所有级别(从初学者到专家)的最佳免费视频转换器。 1. Geekersoft免费在线视频转换 最好的免…

kafka(二)安装部署(2)windows

一、前提 安装Kafka之前,需要安装JDK、Zookeeper、Scala, 本次安装版本选择: JDK:1.8 Zookeeper:3.6.4 Scala:2.12 Kafka:3.5.2 1、jdk Java Downloads | Oracle 见jdk下载安装。 2、Zookeeper 下载…

C# Winform中制作精美控件(2)

仓库温度监控系统重有个控件,就是温度监控,还是比较精美的,那么我们来看看制作的要点有哪些。 前面我们讨论过布局和圆角按钮。这节主要关注温度计控件 1. 布局: 两个Panel将界面分位上下两个部分,Dock.Top Dock.Fil…

关于小程序内嵌H5页面交互的问题?

有木有遇到?有木有遇到。 小程序内嵌了H5,然后H5某个按钮,需要打开小程序某个页面进行信息完善或登记,登记后要返回H5页面,而H5页面要动态显示刚才在小程序页面登记的信息。 操作流程是这样: 方案1&#…

编译原理期末复习

BUCT往年试题为导向的复习 标*的为往年真题 目录 1.基本概念 *例题(编译主要阶段) 编译程序与解释性程序区别 LL(1)概念 2.正则表达式转DFA (1)正则表达式转NFA 第一种方法(编程时常用) 第二种(考试时常用) &#xff08…

MK的前端精华笔记

文章目录 MK的前端精华笔记第一阶段:前端基础入门1、(1)、(2)、 2、3、4、5、6、7、 第二阶段:组件化与移动WebAPP开发1、(1)、(2)、 2、3、4、5、6、7、 第三…

【JavaEE】Spring Web MVC详解

一.基本概念. 1.什么是Spring Web MVC? 官方链接: https://docs.spring.io/spring-framework/reference/web/webmvc.html Spring Web MVC is the original web framework built on the Servlet API and has been included in the Spring Framework from the very beginning…

【ajax基础】回调函数地狱

一:什么是回调函数地狱 在一个回调函数中嵌套另一个回调函数(甚至一直嵌套下去),形成回调函数地狱 回调函数地狱存在问题: 可读性差异常捕获严重耦合性严重 // 1. 获取默认第一个省份的名字axios({url: http://hmaj…

5.什么是C语言

什么是 C 语言? C语言是一种用于和计算机交流的高级语言, 它既具有高级语言的特点,又具有汇编语言的特点 非常接近自然语言程序的执行效率非常高 C语言是所有编程语言中的经典,很多高级语言都是从C语言中衍生出来的, 例如:C、C#、Object-C、…

Android招聘市场技术要求越来越高,从事三年开发是否应该考虑转行?

UI这块知识是现今使用者最多的。当年火爆一时的Android入门培训,学会这小块知识就能随便找到不错的工作了。 不过很显然现在远远不够了,拒绝无休止的CV,亲自去项目实战,读源码,研究原理吧! 《Framework精编…

Unity 字体创建时候容易导致字体文件不正确的一种情况

上面得到了两种字体格式,一种是TextMeshPro的,另一种是Unity UI系统中默认使用的字体资源。其原因是创建的位置不同导致的。 1.下面是TextMeshPro字体创建的位置 2:下面是Unity UI系统中默认使用的字体资源

【FreeRTOS】任务状态改进播放控制

这里写目录标题 1 任务状态1.1 阻塞状态(Blocked)1.2 暂停状态(Suspended)1.3 就绪状态(Ready)1.4 完整的状态转换图 2 举个例子3 编写代码 参考《FreeRTOS入门与工程实践(基于DshanMCU-103).pdf》 本节课实现音乐任务的创建,音乐播放的暂停与继续播放,删…

算法竞赛创新实践总结

目录 1 算法题目................................... 3 1.1 盛最多水的容器.......................... 3 1.1.1 题目................................ 3 1.1.2 双指针.............................. 4 1.1.3 代码................................ 5 1.2 分巧克力...…

spring-依赖注入DI

Setter注入: 1、引用类型:在bean中定义引用类型属性并提供可访问的set方法,配置中使用property标签ref属性注入引用类型对象; 2、简单类型:在bean中定义引用类型属性并提供可访问的set方法,在配置中使用pr…

反馈时延与端到端拥塞控制

先从 越来越无效的拥塞控制 获得一个直感。 开局一张图,剩下全靠编。这是一道习题: 这图来自《高性能通信网络(第二版)》,2002 年的书,很好很高尚,目前这种书不多了。不准备做这道题,但意思要明白&#x…

Docker 拉取镜像失败处理 配置使用代理拉取

解决方案 1、在 /etc/systemd/system/docker.service.d/http-proxy.conf 配置文件中添加代理信息 2、重启docker服务 具体操作如下: 创建 dockerd 相关的 systemd 目录,这个目录下的配置将覆盖 dockerd 的默认配置 代码语言:javascript 复…

手撕RPC——前言

手撕RPC——前言 一、RPC是什么?二、为什么会出现RPC三、RPC的原理3.1 RPC是如何做到透明化远程服务调用?3.2 如何实现传输消息的编解码? 一、RPC是什么? RPC(Remote Procedure Call,远程过程调用&#xff…