机器学习深度学习——卷积神经网络(LeNet)

👨‍🎓作者简介:一位即将上大四,正专攻机器学习的保研er
🌌上期文章:机器学习&&深度学习——池化层
📚订阅专栏:机器学习&&深度学习
希望文章对你们有所帮助

卷积神经网络(LeNet)

  • 引言
  • LeNet
  • 模型训练
  • 小结

引言

之前的内容中曾经将softmax回归模型和多层感知机应用于Fashion-MNIST数据集中的服装图片。为了能应用他们,我们首先就把图像展平成了一维向量,然后用全连接层对其进行处理。
而现在已经学习过了卷积层的处理方法,我们就可以在图像中保留空间结构。同时,用卷积层代替全连接层的另一个好处是:模型更简单,所需参数更少。
LeNet是最早发布的卷积神经网络之一,之前出来的目的是为了识别图像中的手写数字。

LeNet

总体看,由两个部分组成:
1、卷积编码器:由两个卷积层组成
2、全连接层密集快:由三个全连接层组成
在这里插入图片描述
上图中就是LeNet的数据流图示,其中汇聚层也就是池化层。
最终输出的大小是10,也就是10个可能结果(0-9)。
每个卷积块的基本单元是一个卷积层、一个sigmoid激活函数和平均池化层(当年没有ReLU和最大池化层)。每个卷积层使用5×5卷积核和一个sigmoid激活函数。
这些层的作用就是将输入映射到多个二维特征输出,通常同时增加通道的数量。(从上图容易看出:第一卷积层有6个输出通道,而第二个卷积层有16个输出通道;每个2×2池操作(步幅也为2)通过空间下采样将维数减少4倍)。卷积的输出形状那是由批量大小、通道数、高度、宽度决定。
为了将卷积块的输出传递给稠密块,我们必须在小批量中展平每个样本(也就是把四维的输入转换为全连接层期望的二维输入,第一维索引小批量中的样本,第二维给出给个样本的平面向量表示)。
LeNet的稠密块有三个全连接层,分别有120、84和10个输出。因为我们在执行分类任务,所以输出层的10维对应于最后输出结果的数量(代表0-9是个结果)。
深度学习框架实现此类模型非常简单,用一个Sequential块把需要的层连接在一个就可以了,我们对原始模型做一个小改动,去掉最后一层的高斯激活:

import torch
from torch import nn
from d2l import torch as d2lnet = nn.Sequential(# 输入图像和输出图像都是28×28,因此我们要先进行填充2格nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Flatten(),nn.Linear(16 * 5 * 5, 120), nn.Sigmoid(),nn.Linear(120, 84), nn.Sigmoid(),nn.Linear(84, 10)
)

上面的模型图示就为:
在这里插入图片描述
我们可以先检查模型,在每一层打印输出的形状:

X = torch.rand(size=(1, 1, 28, 28), dtype=torch.float32)
for layer in net:X = layer(X)print(layer.__class__.__name__, 'output shape:\t', X.shape)

输出结果:

Conv2d output shape: torch.Size([1, 6, 28, 28])
Sigmoid output shape: torch.Size([1, 6, 28, 28])
AvgPool2d output shape: torch.Size([1, 6, 14, 14])
Conv2d output shape: torch.Size([1, 16, 10, 10])
Sigmoid output shape: torch.Size([1, 16, 10, 10])
AvgPool2d output shape: torch.Size([1, 16, 5, 5])
Flatten output shape: torch.Size([1, 400])
Linear output shape: torch.Size([1, 120])
Sigmoid output shape: torch.Size([1, 120])
Linear output shape: torch.Size([1, 84])
Sigmoid output shape: torch.Size([1, 84])
Linear output shape: torch.Size([1, 10])

模型训练

既然已经实现了LeNet,现在可以查看它在Fashion-MNIST数据集上的表现:

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

计算成本较高,因此使用GPU来加快训练。为了进行评估,对之前的evaluate_accuracy进行修改,由于完整的数据集位于内存中,因此在模型使用GPU计算数据集之前,我们需要将其复制到显存中。

def evaluate_accuracy_gpu(net, data_iter, device=None):"""使用GPU计算模型在数据集上的精度"""if isinstance(net, nn.Module):net.eval()  # 设置为评估模式if not device:device = next(iter(net.parameters())).device# 正确预测的数量,总预测的数量metric = d2l.Accumulator(2)with torch.no_grad():for X, y in data_iter:if isinstance(X, list):# BERT微调所需(后面内容)else:X = X.to(device)y = y.to(device)metric.add(d2l.accuracy(net(X), y), y.numel())return metric[0] / metric[1]

要使用GPU,我们要在正向和反向传播之前,将每一小批量数据移动到我们GPU上。
如下所示的train_ch6类似于之前定义的train_ch3。以下训练函数假定从高级API创建的模型作为输入,并进行相应的优化。
使用Xavier来随机初始化模型参数。有关于Xavier的推导和原理可以看下面的文章:
机器学习&&深度学习——数值稳定性和模型化参数(详细数学推导)
与全连接层一样,使用交叉熵损失函数和小批量随机梯度下降,代码如下:

def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):  #@save"""用GPU训练模型"""def init_weights(m):if type(m) == nn.Linear or type(m) == nn.Conv2d:nn.init.xavier_uniform_(m.weight)net.apply(init_weights)print('training on', device)net.to(device)optimizer = torch.optim.SGD(net.parameters(), lr=lr)loss = nn.CrossEntropyLoss()animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],legend=['train loss', 'train acc', 'test acc'])timer, num_batches = d2l.Timer(), len(train_iter)for epoch in range(num_epochs):# 训练损失之和,训练准确率之和,样本数metric = d2l.Accumulator(3)net.train()for i, (X, y) in enumerate(train_iter):timer.start()optimizer.zero_grad()X, y = X.to(device), y.to(device)y_hat = net(X)l = loss(y_hat, y)optimizer.step()with torch.no_grad():metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])timer.stop()train_l = metric[0] / metric[2]train_acc =  metric[1] / metric[2]if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:animator.add(epoch + (i+1) / num_batches, (train_l, train_acc, None))test_acc = evaluate_accuracy_gpu(net, test_iter)animator.add(epoch + 1, (None, None, test_acc))print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, 'f'test acc {test_acc:.3f}')print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec 'f'on {str(device)}')

此时我们可以开始训练和评估LeNet模型:

lr, num_epochs = 0.9, 10
train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
d2l.plt.show()

运行输出(这边我没有用远程的GPU,在自己本地跑了,本地只有CPU):

training on cpu
loss 0.477, train acc 0.820, test acc 0.795
8004.2 examples/sec on cpu

运行图片:
在这里插入图片描述

小结

1、卷积神经网络(CNN)是一类使用卷积层的网络
2、在卷积神经网络中,我们组合使用卷积层、非线性激活函数和池化层
3、为了构造高性能的卷积神经网络,我们通常对卷积层进行排列,逐渐降低其表示的空间分辨率,同时增加通道数
4、传统卷积神经网络中,卷积块编码得到的表征在输出之前需要由一个或多个全连接层进行处理

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/26351.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Python Opencv实践 - 基本图像IO操作

import numpy as np import cv2 as cv import matplotlib.pyplot as plt#读取图像 #cv2.IMREAD_COLOR: 读取彩色图像,忽略alpha通道,也可以直接写1 #cv2.IMREAD_GRAYSCALE: 读取灰度图,也可以直接写0 #cv2.IMREAD_UNCHANGED: 读取…

C高级【day4】

思维导图: 写一个函数,获取用户的uid和gid并使用变量接收: #!/bin/bashfunction get_uid {my_uidid -umy_gidid -g }get_uid echo "当前用户的UID:$my_uid" echo "当前用户的GID:$my_gid"整理冒泡…

论文代码学习—HiFi-GAN(4)——模型训练函数train文件具体解析

文章目录 引言正文模型训练代码整体训练过程具体训练细节具体运行流程 多GPU编程main函数(通用代码)完整代码 总结引用 引言 这里翻译了HiFi-GAN这篇论文的具体内容,具体链接。这篇文章还是学到了很多东西,从整体上说&#xff0c…

FPGA学习——Altera IP核调用之PLL篇

文章目录 一、IP核1.1 IP核简介1.2 FPGA中IP核的分类1.3 IP核的缺陷 二、PLL简介2.1 什么是PLL2.2 PLL结构图2.3 C4开发板上PLL的位置 三、IP核调用步骤四、编写测试代码五、总结 一、IP核 1.1 IP核简介 IP核(知识产权核),是在集成电路的可…

8-7 homework

1.思维导图 2.写一个函数&#xff0c;获取用户的uid和gid并使用变量接收 3.bubble_sort #include <stdio.h>//先排好的都是放在最后的&#xff0c;所以for的内层限制条件是不把后面的计算在内的&#xff0c;内层只循环前面的 int main(){int a [10]{11,42,3,24,65,16,73…

数据结构——红黑树基础(博文笔记)

数据结构在查找这一章里介绍过这些数据结构&#xff1a;BST&#xff0c;AVL&#xff0c;RBT&#xff0c;B和B。 除去RBT&#xff0c;其他的数据结构之前的学过&#xff0c;都是在BST的基础上进行微小的限制。 1.比如AVL是要求任意节点的左右子树深度之差绝对值不大于1,由此引出…

centos7 yum源安装出错及更新问题

如下 首先&#xff0c;在搜索jdk时报错如下&#xff1a; 解决办法 1、进入 yum的repo目录 cd /etc/yum.repos.d/2、修改所有的CentOS文件内容 sed -i s/mirrorlist/#mirrorlist/g /etc/yum.repos.d/CentOS-*sed -i s|#baseurlhttp://mirror.centos.org|baseurlhttp://vau…

Report Sharp-Shooter Lite Edition Crack

Report Sharp-Shooter Lite Edition Crack 报告Sharp Shooter™ 是为.NET Framework设计的&#xff0c;使用C#编写&#xff0c;并且只包含100%的托管代码。Report Sharp Shooter能够从多个数据源生成任何复杂的报告&#xff0c;并将生成的报告导出为大多数格式&#xff0c;包括…

机器学习笔记之优化算法(八)简单认识Wolfe Condition的收敛性证明

机器学习笔记之优化算法——简单认识Wolfe Condition收敛性证明 引言回顾&#xff1a; Wolfe \text{Wolfe} Wolfe准则准备工作推导条件介绍推导结论介绍 关于 Wolfe \text{Wolfe} Wolfe准则收敛性证明的推导过程 引言 上一节介绍了非精确搜索方法—— Wolfe \text{Wolfe} Wolf…

Wavefront .OBJ文件格式解读【3D】

OBJ&#xff08;或 .OBJ&#xff09;是一种几何定义文件格式&#xff0c;最初由 Wavefront Technologies 为其高级可视化器动画包开发。 该文件格式是开放的&#xff0c;已被其他 3D 图形应用程序供应商采用。 OBJ 文件格式是一种简单的数据格式&#xff0c;仅表示 3D 几何体&…

node.js安装

下载 https://nodejs.org/en 安装 D:\Program Files\nodejs 配置 D:\Program Files\nodejs 目录下新建 node_cache 和 node_global 在cmd管理员身份运行&#xff1a; npm config set prefix "D:\Program Files\nodejs\node_global" npm config set cache &qu…

算法通关村第五关——HashMap和队列问题分析

1.HashMap 1.1Hash的概念和基本特征 哈希(Hash)&#xff1a;也称为散列。就是把任意长度的输入&#xff0c;通过散列算法&#xff0c;变换成固定长度的输出&#xff0c;这个输出值就是散列值。 假设数组array存放的是1到15这些数&#xff0c;现在要存在一个大小是7的Hash表中…

Android 刷新与显示

目录 屏幕显示原理&#xff1a; 显示刷新的过程 VSYNC机制具体实现 小结&#xff1a; 屏幕显示原理&#xff1a; 过程描述&#xff1a; 应用向系统服务申请buffer 系统服务返回一个buffer给应用 应用开始绘制&#xff0c;绘制完成就提交buffer&#xff0c;系统服务把buffer数据…

7_分类算法—逻辑回归

文章目录 逻辑回归&#xff1a;1 Logistic回归&#xff08;二分类问题&#xff09;1.1 sigmoid函数1.2 Logistic回归及似然函数&#xff08;求解&#xff09;1.3 θ参数求解1.4 Logistic回归损失函数1.5 LogisticRegression总结 2 Softmax回归&#xff08;多分类问题&#xff0…

Oracle安装与配置

一 把windows2003拖到vm里去 1.1 加一块虚拟网卡 1.2 把网卡信息改一下 配完之后点一下应用 这个地方改成一个不是1但是在255之内的数&#xff0c;就可以 二 后面调试很久&#xff0c;失败了 后面还有一些步骤&#xff0c;但是总的来说&#xff0c;这次安装失败了&#xf…

适合自己企业的erp系统怎么选?这8条关键因素缺一不可!

一文看懂&#xff1a;如何选择适合自己企业的ERP系统&#xff1f;选型过程中有哪些关键因素需要考虑&#xff1f; 无论你是多大规模的企业&#xff0c;看懂这一篇&#xff0c;你都能受用无穷。 哪怕你需求复杂&#xff0c;现成ERP系统无法满足&#xff0c;最后我也给出了一条…

政府大数据资源中心建设总体方案[56页PPT]

导读&#xff1a;原文《政府大数据资源中心建设总体方案[56页PPT]》&#xff08;获取来源见文尾&#xff09;&#xff0c;本文精选其中精华及架构部分&#xff0c;逻辑清晰、内容完整&#xff0c;为快速形成售前方案提供参考。 完整版领取方式 完整版领取方式&#xff1a; 如需…

Zebec Protocol ,不止于 Web3 世界的 “Paypal”

Paypal是传统支付领域的巨头企业&#xff0c;在北美支付市场占有率约为77%以上。从具体的业务数据看&#xff0c;在8月初&#xff0c;Paypal公布的2023年第二季度财报显示&#xff0c;PayPal第二季度净营收为73亿美元&#xff0c;净利润为10.29亿美元。虽然Paypal的净利润相交去…

javaWeb项目--二级评论完整思路

先来看前端需要什么吧&#xff1a; 通过博客id&#xff0c;首先需要显示所有一级评论&#xff0c;包括评论者的头像&#xff0c;昵称&#xff0c;评论时间&#xff0c;评论内容 然后要显示每个一级评论下面的二级评论&#xff0c;包括&#xff0c;评论者的头像&#xff0c;昵称…

6.s081/6.1810(Fall 2022)Lab5: Copy-on-Write Fork for xv6

前言 本来往年这里还有个Lazy Allocation的&#xff0c;今年不知道为啥直接给跳过去了。. 其他篇章 环境搭建 Lab1: Utilities Lab2: System calls Lab3: Page tables Lab4: Traps Lab5: Copy-on-Write Fork for xv6 参考链接 官网链接 xv6手册链接&#xff0c;这个挺重要…