《动手学深度学习(PyTorch版)》笔记6.3

注:书中对代码的讲解并不详细,本文对很多细节做了详细注释。另外,书上的源代码是在Jupyter Notebook上运行的,较为分散,本文将代码集中起来,并加以完善,全部用vscode在python 3.9.18下测试通过,同时对于书上部分章节也做了整合。

Chapter6 Convolutional Neural Network(CNN)

6.3 LeNet

在这里插入图片描述

LeNet模型中每个卷积块中的基本单元是一个卷积层、一个sigmoid激活函数和平均汇聚层。虽然ReLU和最大汇聚层更有效,但它们在20世纪90年代还没有出现。每个卷积层使用 5 × 5 5\times 5 5×5卷积核和一个sigmoid激活函数,这些层将输入映射到多个二维特征输出,通常同时增加通道的数量。第一卷积层有6个输出通道,而第二个卷积层有16个输出通道,每个 2 × 2 2\times2 2×2池操作(步幅2)通过空间下采样将维数减少4倍。为了将卷积块的输出传递给稠密块,我们必须在小批量中展平每个样本。换言之,我们将这个四维输入转换成全连接层所期望的二维输入(第一个维度索引小批量中的样本,第二个维度给出每个样本的平面向量表示)。LeNet的稠密块有三个全连接层,分别有120、84和10个输出。因为我们在执行分类任务,所以输出层的10维对应于最后输出结果的数量。

import torch
from torch import nn
from d2l import torch as d2l
import matplotlib.pyplot as pltnet = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Flatten(),nn.Linear(16 * 5 * 5, 120), nn.Sigmoid(),nn.Linear(120, 84), nn.Sigmoid(),nn.Linear(84, 10))X = torch.rand(size=(1, 1, 28, 28), dtype=torch.float32)
for layer in net:X = layer(X)print(layer.__class__.__name__,'output shape: \t',X.shape)batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)def evaluate_accuracy_gpu(net, data_iter, device=None): #@save"""使用GPU计算模型在数据集上的精度"""if isinstance(net, nn.Module):net.eval()  # 设置为评估模式if not device:device = next(iter(net.parameters())).devicemetric = d2l.Accumulator(2)#metric[0]:正确预测的数量,metric[1]:总预测的数量。with torch.no_grad():for X, y in data_iter:if isinstance(X, list):X = [x.to(device) for x in X]# BERT微调所需,将输入转移到GPU上(之后将介绍)else:X = X.to(device)y = y.to(device)metric.add(d2l.accuracy(net(X), y), y.numel())return metric[0] / metric[1]def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):#@save"""用GPU训练模型"""def init_weights(m):if type(m) == nn.Linear or type(m) == nn.Conv2d:nn.init.xavier_uniform_(m.weight)net.apply(init_weights)print('training on', device)net.to(device)optimizer = torch.optim.SGD(net.parameters(), lr=lr)loss = nn.CrossEntropyLoss()animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],legend=['train loss', 'train acc', 'test acc'])timer, num_batches = d2l.Timer(), len(train_iter)for epoch in range(num_epochs):metric = d2l.Accumulator(3)#metric[i]分别是训练损失之和,训练准确率之和,样本数net.train()for i, (X, y) in enumerate(train_iter):timer.start()optimizer.zero_grad()X, y = X.to(device), y.to(device)y_hat = net(X)l = loss(y_hat, y)l.backward()optimizer.step()with torch.no_grad():metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])timer.stop()train_l = metric[0] / metric[2]train_acc = metric[1] / metric[2]if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:animator.add(epoch + (i + 1) / num_batches,(train_l, train_acc, None))test_acc = evaluate_accuracy_gpu(net, test_iter)animator.add(epoch + 1, (None, None, test_acc))print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, test acc {test_acc:.3f}')print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec on {str(device)}')lr, num_epochs = 0.9, 10
train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
plt.show()

训练结果:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/663971.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Python实现PDF到HTML的转换

PDF文件是共享和分发文档的常用选择,但提取和再利用PDF文件中的内容可能会非常麻烦。而利用Python将PDF文件转换为HTML是解决此问题的理想方案之一,这样做可以增强文档可访问性,使文档可搜索,同时增强文档在不同场景中的实用性。此…

后端——go系统学习笔记(不断更新中......)

数组 固定大小 初始化 arr1 : [3]int{1, 2, 3} arr2 : [...]int{1, 2, 3} var arr3 []int var arr4 [4]int切片 长度是动态的 初始化 arr[0:3] slice : []int{1,2,3} slice : make([]int, 10)len和cap len是获取切片、数组、字符串的长度——元素的个数cap是获取切片的容量—…

深度学习入门笔记(七)卷积神经网络CNN

我们先来总结一下人类识别物体的方法: 定位。这一步对于人眼来说是一个很自然的过程,因为当你去识别图标的时候,你就已经把你的目光放在了图标上。虽然这个行为不是很难,但是很重要。看线条。有没有文字,形状是方的圆的,还是长的短的等等。看细节。纹理、颜色、方向等。卷…

Java正则表达式之Pattern和Matcher

目录 前言一、Pattern和Matcher的简单使用二、Pattern详解2.1 Pattern 常用方法2.1.1 compile(String regex)2.1.2 matches(String regex, CharSequence input)2.1.3 split(CharSequence input)2.1.4 pattern()2.1.5 matcher(CharSequence input) 三、Matcher详解3.1 Matcher 常…

JSP和JSTL板块:第三节 JSP四大域对象 来自【汤米尼克的JAVAEE全套教程专栏】

JSP和JSTL板块:第三节 JSP四大域对象 一、page范围二、request范围三、session范围四、application范围 在服务器和客户端之间、各个网页之间、哪怕同一个网页之内,总是需要传递各种参数值,这时JSP的内置对象就是传递这些参数的载具。内置对象…

​(四)hive的搭建2

在&#xff08;三&#xff09;hive的搭建1中我们搭建好了hive环境&#xff0c;但是只能本地访问&#xff0c;在本节中配置Hive的访问方式。 1.元数据服务的方式 1.1 编辑hive-site.xml sudo vi hive-site.xml 在文件最后增加以下内容 <!– 指定存储元数据要连接的地址 –…

无里程计下的纯跟踪算法实现

文章目录 概要生成相机坐标系下的三维坐标无里程计下的纯跟踪算法实现 概要 当你只有一个相机的时候&#xff0c;想要快速实现机器人跟随功能&#xff0c;没有里程计的情况下&#xff0c;就可以看这里了。这篇博文实现了一个无里程计下的纯跟踪算法。 生成相机坐标系下的三维…

1、安全开发-Python爬虫EDUSRC目标FOFA资产Web爬虫解析库

用途&#xff1a;个人学习笔记&#xff0c;有所借鉴&#xff0c;欢迎指正 前言&#xff1a; 主要包含对requests库和Web爬虫解析库的使用&#xff0c;python爬虫自动化&#xff0c;批量信息收集 Python开发工具&#xff1a;PyCharm 2022.1 激活破解码_安装教程 (2022年8月25日…

Linux下find命令详解

find #查找文件 #按照文件名、大小、时间、权限、类型、所属者、所属组来搜索文件 格式&#xff1a; find 查找路径 查找条件 具体条件&#xff08;按文件名或时间大小等&#xff09; 操作 注意&#xff1a; find命令默认的操作是print输出 find是检索…

MATLAB绘制电磁场

MATLAB绘制电磁场举例: clc;close all;clear all;warning off;%清除变量 rand(seed, 100); randn(seed, 100); format long g; m12 for k1:m for j1:m if k1 V(j,k)1; elseif((j1)|(jm)|(km)) V(j,k)0; else …

PKG系统安装包及IPSW固件:MacOS 11-14 Sonoma 正式版

MacOS 14 Sonoma&#xff0c;为提高生产力和创造力带来了全新的功能&#xff0c;有了更多使用小部件和令人惊叹的新屏幕保护程序进行个性化设置的方法&#xff0c;对Safari浏览器和视频会议进行了重大更新&#xff0c;以及优化的游戏体验——Mac体验比以往任何时候都更好。 mac…

贝叶斯的缺点

贝叶斯方法是一种统计学习方法&#xff0c;通过利用贝叶斯定理来计算给定先验概率的情况下&#xff0c;后验概率的条件概率。虽然贝叶斯方法在许多领域中应用广泛且有效&#xff0c;但也存在一些缺点。以下是一些贝叶斯方法的缺点的例子&#xff1a; 1、先验概率的选择 贝叶斯方…

第二十四天| 77. 组合

Leetcode 77. 组合 题目链接&#xff1a;77 组合 题干&#xff1a;给定两个整数 n 和 k&#xff0c;返回范围 [1, n] 中所有可能的 k 个数的组合。你可以按 任何顺序 返回答案。 思考&#xff1a;回溯法。把回溯法的搜索过程抽象为树形结构。 每次从集合中选取元素&#xff0…

vio参数文件内相机imu参数的修改

imu标定工具 https://github.com/mintar/imu_utils网络上有各种IMU校准工具和校准教程&#xff0c;曾经花费了巨大精力跟着各种教程去跑校准。 然而&#xff0c;标定使用的数据都是在静止状态下录制的&#xff0c;我们在使用vio或者imu-cam联合标定的时候&#xff0c;imu确是处…

计算机视觉实战项目4(单目测距与测速+摔倒检测+目标检测+目标跟踪+姿态识别+车道线识别+车牌识别+无人机检测+A_路径规划+行人车辆计数+动物识别等)

基于YOLOv5的无人机视频检测与计数系统 摘要&#xff1a; 无人机技术的快速发展和广泛应用给社会带来了巨大的便利&#xff0c;但也带来了一系列的安全隐患。为了实现对无人机的有效管理和监控&#xff0c;本文提出了一种基于YOLOv5的无人机视频检测与计数系统。该系统通过使用…

AJAX-认识URL

定义 概念&#xff1a;URL就是统一资源定位符&#xff0c;简称网址&#xff0c;用于访问网络上的资源 组成 协议 http协议&#xff1a;超文本传输协议&#xff0c;规定浏览器和服务器之间传输数据的格式&#xff1b;规定了浏览器发送及服务器返回内容的格式 协议范围&#xf…

flask基于Python的期货交易模拟系统的django-afl61-vue

期货交易模拟系统是一个便于用户在线查看期货投资、取消投资、风险控制、账户资金、持仓资金等&#xff0c;管理员进行管理的平台。因此本文主要论述了系统开发的过程和实现的功能&#xff0c;结合Web技术来实现的期货交易模拟系统。本系统以软件工程理论为开发基础&#xff0c…

UE4 C++ 静态加载类和资源

静态加载类和资源&#xff1a;指在编译时加载&#xff0c;并且只能在构造函数中编写代码 .h //增加所需组件的头文件 #include "Components/SceneComponent.h" //场景组件 #include "Components/StaticMeshComponent.h" //静态网格体组件 #include &qu…

SpringBoot实战2

目录 1.如何返回两个类型的数据&#xff1f;User和Booth 2.如何使用MyBatis遍历一个数组进行查询&#xff1f; 3.前端要的数据太多太杂&#xff0c;我们拼接多个List&#xff0c;前端找数据困难&#xff0c;浪费时间。因此我们进行三表联表查询。 1.首先创建一个vo包&#x…

yo!这里是c++IO流相关介绍

目录 前言 C语言的输入输出 CIO流基本介绍 流的概念 IO流类库 iostream fstream stringstream 后记 前言 学过C语言的输入输出相关知识点的童鞋应该多多少少会觉得有些许麻烦&#xff0c;反正我就是这么觉得的&#xff0c;scanf、printf等函数不仅数量众多&#xff0c…