李沐23_LeNet——自学笔记

手写的数字识别

知名度最高的数据集:MNIST
1.训练数据:50000

2.测试数据:50000

3.图像大小:28✖28

4.10类

总结

1.LeNet是早期成功的神经网络

2.先使用卷积层来学习图片空间信息

3.使用全连接层来转换到类别空间

代码实现

LeNet由两部分组成:卷积编码器和全连接层密集块

import torch
from torch import nn
from d2l import torch as d2lclass Reshape(torch.nn.Module):def forward(self,x):return x.view(-1,1,28,28) # 原图:28✖28,填充后是32✖32net = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.Sigmoid(), # 6个通道nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Flatten(),nn.Linear(16 * 5 * 5, 120), nn.Sigmoid(),nn.Linear(120, 84), nn.Sigmoid(),nn.Linear(84, 10))

我们将一个大小为28✖28的单通道(黑白)图像通过LeNet。

X = torch.rand(size=(1, 1, 28, 28), dtype=torch.float32)
for layer in net:X = layer(X)print(layer.__class__.__name__,'output shape: \t',X.shape)
Conv2d output shape: 	 torch.Size([1, 6, 28, 28])
Sigmoid output shape: 	 torch.Size([1, 6, 28, 28])
AvgPool2d output shape: 	 torch.Size([1, 6, 14, 14])
Conv2d output shape: 	 torch.Size([1, 16, 10, 10])
Sigmoid output shape: 	 torch.Size([1, 16, 10, 10])
AvgPool2d output shape: 	 torch.Size([1, 16, 5, 5])
Flatten output shape: 	 torch.Size([1, 400])
Linear output shape: 	 torch.Size([1, 120])
Sigmoid output shape: 	 torch.Size([1, 120])
Linear output shape: 	 torch.Size([1, 84])
Sigmoid output shape: 	 torch.Size([1, 84])
Linear output shape: 	 torch.Size([1, 10])

在整个卷积块中,与上一层相比,每一层特征的高度和宽度都减小了。

第一个卷积层使用2个像素的填充,来补偿5✖5卷积核导致的特征减少。

相反,第二个卷积层没有填充,因此高度和宽度都减少了4个像素。

随着层叠的上升,通道的数量从输入时的1个,增加到第一个卷积层之后的6个,再到第二个卷积层之后的16个。

同时,每个汇聚层的高度和宽度都减半。最后,每个全连接层减少维数,最终输出一个维数与结果分类数相匹配的输出。

LeNet在Fashion-MNIST数据集上的表现

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)
# 下载fashion_MNIST数据集
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ../data/FashionMNIST/raw/train-images-idx3-ubyte.gz100%|██████████| 26421880/26421880 [00:02<00:00, 9258920.33it/s] Extracting ../data/FashionMNIST/raw/train-images-idx3-ubyte.gz to ../data/FashionMNIST/rawDownloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ../data/FashionMNIST/raw/train-labels-idx1-ubyte.gz100%|██████████| 29515/29515 [00:00<00:00, 171125.91it/s]Extracting ../data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to ../data/FashionMNIST/rawDownloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ../data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz100%|██████████| 4422102/4422102 [00:01<00:00, 3169968.34it/s]Extracting ../data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ../data/FashionMNIST/rawDownloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ../data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz100%|██████████| 5148/5148 [00:00<00:00, 4336669.41it/s]Extracting ../data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/FashionMNIST/raw/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py:558: UserWarning: This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.warnings.warn(_create_warning_msg(
def evaluate_accuracy_gpu(net, data_iter, device=None): #计算模型精度"""使用GPU计算模型在数据集上的精度"""if isinstance(net, nn.Module):net.eval()  # 设置为评估模式if not device:device = next(iter(net.parameters())).device# 正确预测的数量,总预测的数量metric = d2l.Accumulator(2)with torch.no_grad():for X, y in data_iter:if isinstance(X, list):# BERT微调所需的(之后将介绍)X = [x.to(device) for x in X]else:X = X.to(device)y = y.to(device)metric.add(d2l.accuracy(net(X), y), y.numel())return metric[0] / metric[1]

训练函数

1.训练函数train_ch6也类似于3.6节中定义的train_ch3。

2.使用高级API创建的模型作为输入,并进行相应的优化。

3.Xavier随机初始化模型参数。
4.使用交叉熵损失函数和小批量随机梯度下降。

#用GPU训练模型,比第三章多了device
def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):"""用GPU训练模型(在第六章定义)"""def init_weights(m):if type(m) == nn.Linear or type(m) == nn.Conv2d:nn.init.xavier_uniform_(m.weight)net.apply(init_weights)print('training on', device)net.to(device)optimizer = torch.optim.SGD(net.parameters(), lr=lr)loss = nn.CrossEntropyLoss()animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs], # 动画效果legend=['train loss', 'train acc', 'test acc'])timer, num_batches = d2l.Timer(), len(train_iter)for epoch in range(num_epochs):# 训练损失之和,训练准确率之和,样本数metric = d2l.Accumulator(3)net.train()for i, (X, y) in enumerate(train_iter):timer.start()optimizer.zero_grad()X, y = X.to(device), y.to(device)y_hat = net(X)l = loss(y_hat, y)l.backward()optimizer.step()with torch.no_grad():metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])timer.stop()train_l = metric[0] / metric[2]train_acc = metric[1] / metric[2]if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:animator.add(epoch + (i + 1) / num_batches,(train_l, train_acc, None))test_acc = evaluate_accuracy_gpu(net, test_iter)animator.add(epoch + 1, (None, None, test_acc))print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, 'f'test acc {test_acc:.3f}')print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec 'f'on {str(device)}')

训练和评估LeNet-5模型。

lr, num_epochs = 0.9, 10
train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
loss 0.461, train acc 0.827, test acc 0.793
35664.6 examples/sec on cuda:0

在这里插入图片描述


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/804491.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【oracle数据库安装篇一】Linux5.6基于LVM安装oracle10gR2单机

说明 本篇文章主要介绍了Linux5.6基于LVM安装oracle10gR2单机的配置过程&#xff0c;比较详细&#xff0c;基本上每一个配置部分的步骤都提供了完整的脚本&#xff0c;安装部分都提供了简单的说明和截图&#xff0c;帮助你100%安装成功oracle数据库。 安装过程有不明白的地方…

二维相位解包理论算法和软件【全文翻译- DCT相位解包裹(5.3.2)】

5.3.2 基于 DCT 的方法 在本节中,我们将详细介绍如何通过 DCT 算法解决非加权最小二乘相位解缠问题,而不是通过FFT.我们将使用公式 5.53 所定义的二维余弦变换。我们开发的算法等同于 FFT 方法 2(第 5.3.1 节)。与 FFT 方法 I 等价的 DCT 算法也可以推导出来,但我们将其作…

PlayerSettings.WebGL.emscriptenArgs设置无效的问题

1&#xff09;PlayerSettings.WebGL.emscriptenArgs设置无效的问题 2&#xff09;多个小资源包合并为大资源包的疑问 3&#xff09;AssetBundle在移动设备上丢失 4&#xff09;Unity云渲染插件RenderStreaming&#xff0c;如何实现多用户分别有独立的操作 这是第381篇UWA技术知…

Meta 的 Llama 模型系列即将迎来第三次大更新

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗&#xff1f;订阅我们的简报&#xff0c;深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同&#xff0c;从行业内部的深度分析和实用指南中受益。不要错过这个机会&#xff0c;成为AI领…

linux启动流程(s3c2400)

概述 大致流程&#xff1a;内核&#xff08;kernel&#xff09;都是由bootloader程序引导启动的&#xff0c;所以我们应该先烧进去bootloader程序。然后可以通过保存的内核代码或者通过远程连接&#xff08;nfs/tftp&#xff09;的主机下载再运行&#xff0c;再挂载根文件系统。…

ppt从零基础到高手【办公】

第一章&#xff1a;文字排版篇01演示文稿内容基密02文字操作规范03文字排版处理04复习&作业解析第二章&#xff1a;图形图片图表篇05图形化表达06图片艺术化07轻松玩转图表08高效工具&母版统一管理09复习&作业解析10轻松一刻-文字图形小技巧速学第三章&#xff1a;…

SWM341系列应用(RTC、FreeRTOS\RTTHREAD应用和Chip ID)

SWM341系列RTC应用 22.1、RTC的时钟基准 --liuzc 2023-8-17 现象:客户休眠发现RTC走的不准&#xff0c;睡眠2小时才走了5分钟。 分析与解决&#xff1a;经过排查RTC的时钟源是XTAL_32K&#xff0c;由于睡眠时时设置XTAL->CR0&#xff1b;&#xff0c;会把XTAL_32K给关…

C语言:指针详解(1)

目录 一、内存和地址 1.内存 2.究竟该如何理解编址 二、指针变量和地址 1.取地址操作符(&) 2.解引用操作符(*) 3.指针变量的大小 三、指针变量类型的意义 1.指针的解引用 2.指针-整数 3.void*指针 四、const修饰指针 1.const修饰变量 2.const修饰指针变量 五…

【ARM Coresight SOC-600 -- ETF 介绍】

请阅读【ARM Coresight SoC-400/SoC-600 专栏导读】 文章目录 SOC ETFSOC ETF REGISTERScss600_tmc_etf RAM Read Data register(RRD)css600_tmc_etf RAM Read Pointer register(RRP)css600_tmc_etf RAM Write Pointer register(RWP)css600_tmc_etf RAM Write Data regis…

公开课学习——仿抖音直播平台

文章目录 直播抖音的直播原理Java继承直播客户端工具&#xff1a; ffmpeg客户端和网页集成CDN网络——性能提升关键——边缘计算 实时聊天——IM系统怎么实现&#xff1f;——websocketIM系统消息如何转发&#xff1f;直播场景IM系统是什么样子&#xff1f; 直播 抖音的直播原…

安全操作代码优化思路

理论依据 数据增强和样本选择 在训练阶段&#xff0c;您可以考虑添加数据增强来提升模型的鲁棒性和泛化能力。针对人脸检测任务&#xff0c;可以尝试以下改进&#xff1a; 对输入图像进行随机裁剪、缩放、旋转、翻转等数据增强操作&#xff0c;以增加数据的多样性。 使用难样…

操作系统—修改xv6内核调度算法

文章目录 修改xv6内核调度算法1.实验环境2.基于优先级的调度算法(1).基本实现思路(2).实现流程(3).一些问题 3.乐透调度算法(1).思路(2).实现流程(3).一些问题 总结参考资料 修改xv6内核调度算法 1.实验环境 这一次的实验因为是在xv6内核中实现一些调度算法&#xff0c;因此我…

Flutter入门指南

文章目录 一、环境搭建二、基本概念三、创建一个简单的Flutter应用四、常用组件及代码示例五、总结推荐阅读 笔者项目中使用Flutter的模块并不多。虽然笔者还没有机会在项目中正式使用Flutter&#xff0c;但是也在学习Flutter的一些基本用法。本文就是一篇Flutter的入门介绍&am…

浏览器滚动条样式终极方案

首先各个浏览器滚动条保持统一是不可能的&#xff0c;因为浏览器不支持大多数滚动条样式属性 从支持可调整的角度来看&#xff0c;我们一般选择 保持chrome样式&#xff0c;其他浏览器样式使用默认效果保持chrome、火狐样式一致&#xff0c;其他浏览器样式使用默认效果 所以这…

C++智能指针2——unique_ptr和weak_ptr

unique_ptr 一个unique_ptr“拥有”它所指向的对象。 与shared_ptr不同&#xff0c;某个时刻只能有一个unique_ptr指向一个给定对象。 当unique_ptr被销毁时&#xff0c;它所指向的对象也被销毁。 和shared_ptr 不同&#xff0c;没有类似make_shared的标准库函数返回一个un…

【双指针】两数之和|| 输入有序数组

两数之和|| 输入有序数组 给你一个下标从 1 开始的整数数组 numbers &#xff0c;该数组已按 非递减顺序排列 &#xff0c;请你从数组中找出满足相加之和等于目标数 target 的两个数。如果设这两个数分别是 numbers[index1] 和 numbers[index2] &#xff0c;则 1 < index1 …

在Linux系统上实现TCP(socket)通信

一.什么TCP TCP&#xff08;传输控制协议&#xff09;是一种面向连接的、可靠的、基于字节流的传输层通信协议。 二.TCP通信流程 三. TCP 服务器端 1 创建socket int sockfd socket(AF_INET, SOCK_STREAM, 0); //SOCK_STREAM tcp通信2 绑定(bind) struct sockaddr_in myad…

Ubuntu下无法获得锁 / 检测到系统程序错误 / E: Could not get lock /var/lib/apt/lists/lock

这里写自定义目录标题 Ubuntu下无法获得锁 错误 / E: Could not get lock /var/lib/apt/lists/lock Ubuntu下无法获得锁 错误 / E: Could not get lock /var/lib/apt/lists/lock 1、E: Could not get lock /var/lib/apt/lists/lock - open (11: Recource temporarily unavaila…

【双指针】删除有序数组中的重复项Ⅱ

给你一个有序数组 nums &#xff0c;请你 原地 删除重复出现的元素&#xff0c;使得出现次数超过两次的元素只出现两次 &#xff0c;返回删除后数组的新长度。 不要使用额外的数组空间&#xff0c;你必须在 原地 修改输入数组 并在使用 O(1) 额外空间的条件下完成 示例 1&…

C++实现幻方实验

我们这个实验目的是实现大于2的奇数的n阶幻方 根据上述的例子我们可以看到一些规律&#xff0c;显示1放在最上方中间的位置&#xff0c;然后向右上方延申&#xff0c;在达到n这个数字时&#xff0c;停止延申&#xff0c;然后在n的下方开始n1的新一轮延申。明白了原理之后就很容…