李沐23_LeNet——自学笔记

手写的数字识别

知名度最高的数据集:MNIST
1.训练数据:50000

2.测试数据:50000

3.图像大小:28✖28

4.10类

总结

1.LeNet是早期成功的神经网络

2.先使用卷积层来学习图片空间信息

3.使用全连接层来转换到类别空间

代码实现

LeNet由两部分组成:卷积编码器和全连接层密集块

import torch
from torch import nn
from d2l import torch as d2lclass Reshape(torch.nn.Module):def forward(self,x):return x.view(-1,1,28,28) # 原图:28✖28,填充后是32✖32net = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.Sigmoid(), # 6个通道nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Flatten(),nn.Linear(16 * 5 * 5, 120), nn.Sigmoid(),nn.Linear(120, 84), nn.Sigmoid(),nn.Linear(84, 10))

我们将一个大小为28✖28的单通道(黑白)图像通过LeNet。

X = torch.rand(size=(1, 1, 28, 28), dtype=torch.float32)
for layer in net:X = layer(X)print(layer.__class__.__name__,'output shape: \t',X.shape)
Conv2d output shape: 	 torch.Size([1, 6, 28, 28])
Sigmoid output shape: 	 torch.Size([1, 6, 28, 28])
AvgPool2d output shape: 	 torch.Size([1, 6, 14, 14])
Conv2d output shape: 	 torch.Size([1, 16, 10, 10])
Sigmoid output shape: 	 torch.Size([1, 16, 10, 10])
AvgPool2d output shape: 	 torch.Size([1, 16, 5, 5])
Flatten output shape: 	 torch.Size([1, 400])
Linear output shape: 	 torch.Size([1, 120])
Sigmoid output shape: 	 torch.Size([1, 120])
Linear output shape: 	 torch.Size([1, 84])
Sigmoid output shape: 	 torch.Size([1, 84])
Linear output shape: 	 torch.Size([1, 10])

在整个卷积块中,与上一层相比,每一层特征的高度和宽度都减小了。

第一个卷积层使用2个像素的填充,来补偿5✖5卷积核导致的特征减少。

相反,第二个卷积层没有填充,因此高度和宽度都减少了4个像素。

随着层叠的上升,通道的数量从输入时的1个,增加到第一个卷积层之后的6个,再到第二个卷积层之后的16个。

同时,每个汇聚层的高度和宽度都减半。最后,每个全连接层减少维数,最终输出一个维数与结果分类数相匹配的输出。

LeNet在Fashion-MNIST数据集上的表现

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)
# 下载fashion_MNIST数据集
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ../data/FashionMNIST/raw/train-images-idx3-ubyte.gz100%|██████████| 26421880/26421880 [00:02<00:00, 9258920.33it/s] Extracting ../data/FashionMNIST/raw/train-images-idx3-ubyte.gz to ../data/FashionMNIST/rawDownloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ../data/FashionMNIST/raw/train-labels-idx1-ubyte.gz100%|██████████| 29515/29515 [00:00<00:00, 171125.91it/s]Extracting ../data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to ../data/FashionMNIST/rawDownloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ../data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz100%|██████████| 4422102/4422102 [00:01<00:00, 3169968.34it/s]Extracting ../data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ../data/FashionMNIST/rawDownloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ../data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz100%|██████████| 5148/5148 [00:00<00:00, 4336669.41it/s]Extracting ../data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/FashionMNIST/raw/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py:558: UserWarning: This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.warnings.warn(_create_warning_msg(
def evaluate_accuracy_gpu(net, data_iter, device=None): #计算模型精度"""使用GPU计算模型在数据集上的精度"""if isinstance(net, nn.Module):net.eval()  # 设置为评估模式if not device:device = next(iter(net.parameters())).device# 正确预测的数量,总预测的数量metric = d2l.Accumulator(2)with torch.no_grad():for X, y in data_iter:if isinstance(X, list):# BERT微调所需的(之后将介绍)X = [x.to(device) for x in X]else:X = X.to(device)y = y.to(device)metric.add(d2l.accuracy(net(X), y), y.numel())return metric[0] / metric[1]

训练函数

1.训练函数train_ch6也类似于3.6节中定义的train_ch3。

2.使用高级API创建的模型作为输入,并进行相应的优化。

3.Xavier随机初始化模型参数。
4.使用交叉熵损失函数和小批量随机梯度下降。

#用GPU训练模型,比第三章多了device
def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):"""用GPU训练模型(在第六章定义)"""def init_weights(m):if type(m) == nn.Linear or type(m) == nn.Conv2d:nn.init.xavier_uniform_(m.weight)net.apply(init_weights)print('training on', device)net.to(device)optimizer = torch.optim.SGD(net.parameters(), lr=lr)loss = nn.CrossEntropyLoss()animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs], # 动画效果legend=['train loss', 'train acc', 'test acc'])timer, num_batches = d2l.Timer(), len(train_iter)for epoch in range(num_epochs):# 训练损失之和,训练准确率之和,样本数metric = d2l.Accumulator(3)net.train()for i, (X, y) in enumerate(train_iter):timer.start()optimizer.zero_grad()X, y = X.to(device), y.to(device)y_hat = net(X)l = loss(y_hat, y)l.backward()optimizer.step()with torch.no_grad():metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])timer.stop()train_l = metric[0] / metric[2]train_acc = metric[1] / metric[2]if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:animator.add(epoch + (i + 1) / num_batches,(train_l, train_acc, None))test_acc = evaluate_accuracy_gpu(net, test_iter)animator.add(epoch + 1, (None, None, test_acc))print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, 'f'test acc {test_acc:.3f}')print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec 'f'on {str(device)}')

训练和评估LeNet-5模型。

lr, num_epochs = 0.9, 10
train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
loss 0.461, train acc 0.827, test acc 0.793
35664.6 examples/sec on cuda:0

在这里插入图片描述


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/804491.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【oracle数据库安装篇一】Linux5.6基于LVM安装oracle10gR2单机

说明 本篇文章主要介绍了Linux5.6基于LVM安装oracle10gR2单机的配置过程&#xff0c;比较详细&#xff0c;基本上每一个配置部分的步骤都提供了完整的脚本&#xff0c;安装部分都提供了简单的说明和截图&#xff0c;帮助你100%安装成功oracle数据库。 安装过程有不明白的地方…

二维相位解包理论算法和软件【全文翻译- DCT相位解包裹(5.3.2)】

5.3.2 基于 DCT 的方法 在本节中,我们将详细介绍如何通过 DCT 算法解决非加权最小二乘相位解缠问题,而不是通过FFT.我们将使用公式 5.53 所定义的二维余弦变换。我们开发的算法等同于 FFT 方法 2(第 5.3.1 节)。与 FFT 方法 I 等价的 DCT 算法也可以推导出来,但我们将其作…

PlayerSettings.WebGL.emscriptenArgs设置无效的问题

1&#xff09;PlayerSettings.WebGL.emscriptenArgs设置无效的问题 2&#xff09;多个小资源包合并为大资源包的疑问 3&#xff09;AssetBundle在移动设备上丢失 4&#xff09;Unity云渲染插件RenderStreaming&#xff0c;如何实现多用户分别有独立的操作 这是第381篇UWA技术知…

Meta 的 Llama 模型系列即将迎来第三次大更新

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗&#xff1f;订阅我们的简报&#xff0c;深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同&#xff0c;从行业内部的深度分析和实用指南中受益。不要错过这个机会&#xff0c;成为AI领…

linux启动流程(s3c2400)

概述 大致流程&#xff1a;内核&#xff08;kernel&#xff09;都是由bootloader程序引导启动的&#xff0c;所以我们应该先烧进去bootloader程序。然后可以通过保存的内核代码或者通过远程连接&#xff08;nfs/tftp&#xff09;的主机下载再运行&#xff0c;再挂载根文件系统。…

ppt从零基础到高手【办公】

第一章&#xff1a;文字排版篇01演示文稿内容基密02文字操作规范03文字排版处理04复习&作业解析第二章&#xff1a;图形图片图表篇05图形化表达06图片艺术化07轻松玩转图表08高效工具&母版统一管理09复习&作业解析10轻松一刻-文字图形小技巧速学第三章&#xff1a;…

SWM341系列应用(RTC、FreeRTOS\RTTHREAD应用和Chip ID)

SWM341系列RTC应用 22.1、RTC的时钟基准 --liuzc 2023-8-17 现象:客户休眠发现RTC走的不准&#xff0c;睡眠2小时才走了5分钟。 分析与解决&#xff1a;经过排查RTC的时钟源是XTAL_32K&#xff0c;由于睡眠时时设置XTAL->CR0&#xff1b;&#xff0c;会把XTAL_32K给关…

C语言:指针详解(1)

目录 一、内存和地址 1.内存 2.究竟该如何理解编址 二、指针变量和地址 1.取地址操作符(&) 2.解引用操作符(*) 3.指针变量的大小 三、指针变量类型的意义 1.指针的解引用 2.指针-整数 3.void*指针 四、const修饰指针 1.const修饰变量 2.const修饰指针变量 五…

公开课学习——仿抖音直播平台

文章目录 直播抖音的直播原理Java继承直播客户端工具&#xff1a; ffmpeg客户端和网页集成CDN网络——性能提升关键——边缘计算 实时聊天——IM系统怎么实现&#xff1f;——websocketIM系统消息如何转发&#xff1f;直播场景IM系统是什么样子&#xff1f; 直播 抖音的直播原…

安全操作代码优化思路

理论依据 数据增强和样本选择 在训练阶段&#xff0c;您可以考虑添加数据增强来提升模型的鲁棒性和泛化能力。针对人脸检测任务&#xff0c;可以尝试以下改进&#xff1a; 对输入图像进行随机裁剪、缩放、旋转、翻转等数据增强操作&#xff0c;以增加数据的多样性。 使用难样…

操作系统—修改xv6内核调度算法

文章目录 修改xv6内核调度算法1.实验环境2.基于优先级的调度算法(1).基本实现思路(2).实现流程(3).一些问题 3.乐透调度算法(1).思路(2).实现流程(3).一些问题 总结参考资料 修改xv6内核调度算法 1.实验环境 这一次的实验因为是在xv6内核中实现一些调度算法&#xff0c;因此我…

在Linux系统上实现TCP(socket)通信

一.什么TCP TCP&#xff08;传输控制协议&#xff09;是一种面向连接的、可靠的、基于字节流的传输层通信协议。 二.TCP通信流程 三. TCP 服务器端 1 创建socket int sockfd socket(AF_INET, SOCK_STREAM, 0); //SOCK_STREAM tcp通信2 绑定(bind) struct sockaddr_in myad…

C++实现幻方实验

我们这个实验目的是实现大于2的奇数的n阶幻方 根据上述的例子我们可以看到一些规律&#xff0c;显示1放在最上方中间的位置&#xff0c;然后向右上方延申&#xff0c;在达到n这个数字时&#xff0c;停止延申&#xff0c;然后在n的下方开始n1的新一轮延申。明白了原理之后就很容…

计算机专业,不擅长打代码,考研该怎么选择?

考研其实和你的代码能力关系不大 所以在选学校以前可以看看有哪些学校复试是要求上机撸代码的&#xff0c;可能会要求比较严 初试真的不用担心代码问题&#xff0c;我也是基本零编程能力就开始备考考研的... 本人双非科班出身备考408成功上岸&#xff0c;在这里也想给想考40…

css面试题--定位与浮动

1、为什么需要清除浮动&#xff1f; 在非IE浏览器下&#xff0c;容器不设高度且子元素浮动时&#xff0c;容器高度不能被内容撑开&#xff0c;内容会溢出到容器外面而影响布局。这种现象被称为浮动。 浮动的原理&#xff1a;浮动元素脱离文档流&#xff0c;不占用空间&#xff…

使用 wangeditor 解析富文本并生成目录与代码块复制功能

在 Web 开发中&#xff0c;经常需要使用富文本编辑器来编辑和展示内容。wangeditor 是一个强大的富文本编辑器&#xff0c;提供了丰富的功能和灵活的配置&#xff0c;但是官方并没有提供目录导航和代码块的复制功能&#xff0c;所以我自己搞了一个 <template><div cla…

5个超好用的Python工具,赶紧码住!

Python开发软件可根据其用途不同分为两种&#xff0c;Python代码编辑器和Python集成开发工具&#xff0c;两者配合使用极大的提高Python开发人员的编程效率。掌握调试、语法高亮、Project管理、代码跳转、智能提示、自动完成、单元测试、版本控制等操作。 Python常用工具&…

小白新手学习 Python 使用哪个 Linux 系统更好?

对于小白新手学习Python&#xff0c;选择哪个Linux系统是一个很重要的问题&#xff0c;因为不同的Linux发行版&#xff08;distribution&#xff09;有着不同的特点、优势和适用场景。在选择时&#xff0c;需要考虑到易用性、学习曲线、社区支持等因素。 Ubuntu Ubuntu 是一个…

分布式系统中的唯一ID生成方法

通常在分布式系统中&#xff0c;有生成唯一ID的需求&#xff0c;唯一ID有多种实现方式。我们选择其中几种&#xff0c;简单阐述一下实现原理、适用场景、优缺点等信息。 目录 数据库多主复制UUID工单服务器雪花算法总结 数据库多主复制 数据库通常有自增属性&#xff0c;在单机…

CSS 实现无限波浪边框卡片

CSS 实现无限波浪边框卡片 效果展示 鼠标悬停效果&#xff0c;底部色块的边框是无限滚动的波浪 鼠标没有悬停效果 CSS 知识点 CSS 基础知识回顾使用 radial-gradient 实现波浪边框使用 anumate 属性实现波浪边框动画和控制动画运动 波浪实现原理 波浪边框的实现思路其…